/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
namespace {

// TODO(b/172558015): Using the pointer address as the identifier for the tensor
// may lead to collisions. Introduce another way to get a unique id for this
// tensor.
int64_t ToId(const AbstractTensorHandle* t) {
  return static_cast<int64_t>(reinterpret_cast<uintptr_t>(t));
}

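// Builds, in `*result`, a tensor of zeros with the same shape and dtype as
// `t` by executing a ZerosLike op in `ctx`. When tracing, the op is given a
// unique name derived from the tensor's id.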
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
                 AbstractTensorHandle** result) {
  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(t)).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return OkStatus();
}
}  // namespace

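// Registers `gradient_function_factory` as the factory for gradients of
// `op_name`. Returns an AlreadyExists error if a factory was already
// registered for that op.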
Status GradientRegistry::Register(
    const string& op_name, GradientFunctionFactory gradient_function_factory) {
  auto iter = registry_.find(op_name);
  if (iter != registry_.end()) {
    const string error_msg = "Gradient already exists for op: " + op_name + ".";
    return errors::AlreadyExists(error_msg);
  }
  registry_.insert({op_name, gradient_function_factory});
  return OkStatus();
}
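// Builds the GradientFunction for `op` using the factory registered for
// `op.op_name`. Returns a NotFound error if no gradient is registered.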
Status GradientRegistry::Lookup(
    const ForwardOperation& op,
    std::unique_ptr<GradientFunction>* gradient_function) const {
  auto iter = registry_.find(op.op_name);
  if (iter == registry_.end()) {
    const string error_msg = "No gradient defined for op: " + op.op_name + ".";
    return errors::NotFound(error_msg);
  }
  gradient_function->reset(iter->second(op));
  return OkStatus();
}

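// A TapeTensor keeps the underlying handle alive for as long as the tape
// refers to it: every constructor takes a reference and the destructor
// releases it.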
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
  handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
  handle_ = other.handle_;
  handle_->Ref();
}
TapeTensor::~TapeTensor() { handle_->Unref(); }

int64_t TapeTensor::GetID() const { return ToId(handle_); }

tensorflow::DataType TapeTensor::GetDType() const {
  return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }

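// Returns nullptr instead of materializing a zeros tensor: Tape::ComputeGradient
// below passes build_default_zeros_grads=false, so the tape is not expected to
// call this hook to build a zeros value.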
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }

class TapeVSpace
    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
 public:
  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
  ~TapeVSpace() override {}

  // Returns the number of elements in the gradient tensor.
  int64_t NumElements(AbstractTensorHandle* tensor) const override;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  AbstractTensorHandle* AggregateGradients(
      gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;

  // Calls the passed-in backward function.
  // op_type is the op's name provided in RecordOperation.
  Status CallBackwardFunction(
      const string& op_type, GradientFunction* gradient_function,
      const std::vector<int64_t>& unneeded_gradients,
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      absl::Span<AbstractTensorHandle*> result) const override;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  Status BuildOnesLike(const TapeTensor& t,
                       AbstractTensorHandle** result) const override;

  // Looks up the ID of a Gradient.
  int64_t TensorId(AbstractTensorHandle* tensor) const override;

  // Converts a Gradient to a TapeTensor.
  TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;

  void MarkAsResult(AbstractTensorHandle* gradient) const override;

  void DeleteGradient(AbstractTensorHandle* gradient) const override;

 private:
  // The context where the aggregation op `AddN` is to be created.
  AbstractContext* ctx_;
};

// Returns the number of elements in the gradient tensor.
int64_t TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
  // TODO(srbs): It seems like this is used only for performance optimization
  // and not for correctness. The only downside of keeping this 1 seems to be
  // that the gradient accumulation is unbounded and we will never
  // aggressively aggregate accumulated gradients to recover memory.
  // Revisit and fix.
  return 1;
}

// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
    gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
  if (gradient_tensors.size() == 1) {
    return gradient_tensors[0];
  }

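  // Sum more than one gradient by executing an AddN op. The VSpace interface
  // returns a raw handle rather than a Status here, so failures surface as
  // nullptr.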
  AbstractOperationPtr op(ctx_->CreateOperation());
  Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  s = op->AddInputList(gradient_tensors);
  if (!s.ok()) {
    return nullptr;
  }

  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}

// Calls the passed-in backward function.
// op_type is the op's name provided in RecordOperation.
Status TapeVSpace::CallBackwardFunction(
    const string& op_type, GradientFunction* gradient_function,
    const std::vector<int64_t>& unneeded_gradients,
    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
    absl::Span<AbstractTensorHandle*> result) const {
  if (gradient_function == nullptr) {
    return errors::InvalidArgument(
        "Provided null gradient_function for '", op_type, "'.\n",
        "If the intent is to treat this op as non-differentiable consider "
        "using RegisterNotDifferentiable or "
        "NotDifferentiableGradientFunction.");
  }
  return gradient_function->Compute(ctx_, output_gradients, result);
}

Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
                                 AbstractTensorHandle** result) const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return OkStatus();
}

// Looks up the ID of a Gradient.
int64_t TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
  return ToId(tensor);
}

// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
  return TapeTensor(g);
}

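// Marking a result is a no-op for this VSpace; gradient handles need no
// special treatment when returned to the caller.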
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}

void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
  gradient->Unref();
}

void Tape::Watch(const AbstractTensorHandle* t) {
  GradientTape::Watch(ToId(t));
}
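// Records the forward op on the tape: captures the ids and dtypes of
// `inputs`, wraps `outputs` in TapeTensors, and transfers ownership of
// `gradient_function` to the tape, which frees it via the deleter lambda.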
void Tape::RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                           absl::Span<AbstractTensorHandle* const> outputs,
                           GradientFunction* gradient_function,
                           const string& op_name) {
  std::vector<int64_t> input_ids(inputs.size());
  std::vector<tensorflow::DataType> input_dtypes(inputs.size());
  for (int i = 0; i < inputs.size(); i++) {
    input_ids[i] = ToId(inputs[i]);
    input_dtypes[i] = inputs[i]->DataType();
  }
  std::vector<TapeTensor> tape_tensors;
  tape_tensors.reserve(outputs.size());
  for (auto t : outputs) {
    tape_tensors.push_back(TapeTensor(t));
  }
  GradientTape::RecordOperation(
      op_name, tape_tensors, input_ids, input_dtypes,
      [gradient_function]() -> GradientFunction* { return gradient_function; },
      [](GradientFunction* ptr) {
        if (ptr) {
          delete ptr;
        }
      });
}
bool Tape::ShouldRecord(
    absl::Span<const AbstractTensorHandle* const> tensors) const {
  std::vector<int64_t> tensor_ids(tensors.size());
  std::vector<tensorflow::DataType> tensor_dtypes(tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    tensor_ids[i] = ToId(tensors[i]);
    tensor_dtypes[i] = tensors[i]->DataType();
  }
  return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes);
}
void Tape::DeleteTrace(const AbstractTensorHandle* t) {
  GradientTape::DeleteTrace(ToId(t));
}

std::vector<int64_t> MakeTensorIDList(
    absl::Span<AbstractTensorHandle* const> tensors) {
  std::vector<int64_t> ids(tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    ids[i] = ToId(tensors[i]);
  }
  return ids;
}

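// Computes gradients of `targets` with respect to `sources`, optionally
// seeded with `output_gradients`. Sources that are also targets are tracked
// separately so the tape can seed their gradients correctly.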
Status Tape::ComputeGradient(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
    absl::Span<AbstractTensorHandle* const> sources,
    absl::Span<AbstractTensorHandle* const> output_gradients,
    absl::Span<AbstractTensorHandle*> result) {
  TapeVSpace vspace(ctx);
  std::vector<int64_t> target_tensor_ids = MakeTensorIDList(targets);
  std::vector<int64_t> source_tensor_ids = MakeTensorIDList(sources);
  tensorflow::gtl::FlatSet<int64_t> sources_set(source_tensor_ids.begin(),
                                                source_tensor_ids.end());
  std::unordered_map<int64_t, TapeTensor> sources_that_are_targets;
  for (int i = 0; i < target_tensor_ids.size(); ++i) {
    int64_t target_id = target_tensor_ids[i];
    if (sources_set.find(target_id) != sources_set.end()) {
      auto tensor = targets[i];
      sources_that_are_targets.insert(
          std::make_pair(target_id, TapeTensor(tensor)));
    }
  }

  TF_RETURN_IF_ERROR(GradientTape::ComputeGradient(
      vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets,
      output_gradients, result, /*build_default_zeros_grads=*/false));
  return OkStatus();
}

// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
             const char* raw_device_name, ForwardOperation* forward_op_) {
  forward_op_->op_name = op;
  forward_op_->attrs.Reset(op);
  return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
                ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInput(input));
  forward_op_->inputs.push_back(input);
  return OkStatus();
}
Status AddInputList(AbstractOperation* op_,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
  for (auto input : inputs) {
    forward_op_->inputs.push_back(input);
  }
  return OkStatus();
}

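// Each SetAttr* helper mirrors the attribute into `forward_op_->attrs` so the
// gradient function can read it later, then forwards the call to the
// underlying AbstractOperation.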
Status SetAttrString(AbstractOperation* op_, const char* attr_name,
                     const char* data, size_t length,
                     ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, StringPiece(data, length));
  return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
                  ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, static_cast<int64_t>(value));
  return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
                    ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
                   ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
                   DataType value, ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
                    const int64_t* dims, const int num_dims,
                    ForwardOperation* forward_op_) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }
  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  forward_op_->attrs.Set(attr_name, proto);
  return op_->SetAttrShape(attr_name, dims, num_dims);
}
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
                       const AbstractOperation* value,
                       ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
                           const char* value, size_t length,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented "
      "yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
                     AbstractTensorInterface* tensor,
                     ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
                         const void* const* values, const size_t* lengths,
                         int num_values, ForwardOperation* forward_op_) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  forward_op_->attrs.Set(attr_name, v);
  return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
                        const float* values, int num_values,
                        ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const float>(values, num_values));
  return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
                      const int64_t* values, int num_values,
                      ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<const int64_t>(
                     reinterpret_cast<const int64_t*>(values), num_values));
  return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
                       const DataType* values, int num_values,
                       ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const DataType>(values, num_values));
  return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
                       const unsigned char* values, int num_values,
                       ForwardOperation* forward_op_) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const bool>(b.get(), num_values));
  return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
                        const int64_t** dims, const int* num_dims,
                        int num_values, ForwardOperation* forward_op_) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
                           absl::Span<const AbstractOperation*> values,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionList has not been "
      "implemented yet.");
}
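// Executes the forward op, records its outputs on `forward_op_`, looks up the
// registered gradient function in `registry`, and records the operation on
// `tape` so the backward pass can use it.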
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
               absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
               ForwardOperation* forward_op_, Tape* tape,
               const GradientRegistry& registry) {
  TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_->outputs.push_back(retvals[i]);
  }
  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and string consistent.
  forward_op_->attrs.BuildNodeDef();
  std::unique_ptr<GradientFunction> gradient_fn;
  TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn));
  tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(),
                        op_->Name());
  return OkStatus();
}
}  // namespace internal

}  // namespace gradients
}  // namespace tensorflow