xref: /aosp_15_r20/external/tensorflow/tensorflow/c/eager/abstract_operation.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
16 #define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
17 
18 #include <memory>
19 
20 #include "absl/types/span.h"
21 #include "tensorflow/c/eager/abstract_tensor_handle.h"
22 #include "tensorflow/c/tensor_interface.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/status.h"
26 
27 namespace tensorflow {
28 
29 // Abstract interface to an operation.
30 // This interface allows building and executing an operation in either
31 // tracing or immediate execution mode.
32 class AbstractOperation {
33  protected:
34   enum AbstractOperationKind {
35     kGraph,
36     kMlir,
37     kEager,
38     kTfrt,
39     kTape,
40     kOpHandler
41   };
AbstractOperation(AbstractOperationKind kind)42   explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
~AbstractOperation()43   virtual ~AbstractOperation() {}
44 
45  public:
getKind()46   AbstractOperationKind getKind() const { return kind_; }
47 
48   // Release any underlying resources, including the interface object.
49   //
50   // WARNING: The destructor of this class is marked as protected to disallow
51   // clients from directly destroying this object since it may manage it's own
52   // lifetime through ref counting. Thus this must be allocated on the heap and
53   // clients MUST call Release() in order to destroy an instance of this class.
54   virtual void Release() = 0;
55 
56   virtual Status Reset(const char* op, const char* raw_device_name) = 0;
57 
58   virtual const string& Name() const = 0;
59 
60   // Returns the operation's device name.
61   //
62   // The value returned may be different from the one set by SetDeviceName, but
63   // it will be compatible with it: the name will be updated by device placement
64   // logic to refer to the specific device chosen.
65   //
66   // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
67   // returned by DeviceName should be "/device:GPU:*" until a particular GPU is
68   // chosen for the operation by the device placement logic in the
69   // executor. After that, the value returned by DeviceName will be a full
70   // device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
71   virtual const string& DeviceName() const = 0;
72 
73   // Sets the operation device name.
74   //
75   // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
76   // the result will be used as a constraint for device placement. See the
77   // documentation for DeviceName for more details.
78   //
79   // The value will override the previous value - that is, no "merging" of
80   // existing and given constraints will be performed.
81   virtual Status SetDeviceName(const char* name) = 0;
82 
83   virtual Status AddInput(AbstractTensorHandle* input) = 0;
84   virtual Status AddInputList(
85       absl::Span<AbstractTensorHandle* const> inputs) = 0;
86   virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
87                          int* num_retvals) = 0;
88 
89   virtual Status SetAttrString(const char* attr_name, const char* data,
90                                size_t length) = 0;
91   virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
92   virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
93   virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
94   virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
95   virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
96                               const int num_dims) = 0;
97   virtual Status SetAttrShape(const char* attr_name,
98                               const PartialTensorShape shape);
99   virtual Status SetAttrFunction(const char* attr_name,
100                                  const AbstractOperation* value) = 0;
101   virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
102                                      size_t length) = 0;
103   virtual Status SetAttrTensor(const char* attr_name,
104                                AbstractTensorInterface* tensor) = 0;
105   virtual Status SetAttrStringList(const char* attr_name,
106                                    const void* const* values,
107                                    const size_t* lengths, int num_values) = 0;
108   virtual Status SetAttrStringList(const char* attr_name,
109                                    absl::Span<string const> values);
110   virtual Status SetAttrFloatList(const char* attr_name, const float* values,
111                                   int num_values) = 0;
112   virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
113                                 int num_values) = 0;
114   virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
115                                  int num_values) = 0;
116   virtual Status SetAttrBoolList(const char* attr_name,
117                                  const unsigned char* values,
118                                  int num_values) = 0;
119   virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
120                                   const int* num_dims, int num_values) = 0;
121   virtual Status SetAttrFunctionList(
122       const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
123 
124  private:
125   const AbstractOperationKind kind_;
126 };
127 
128 // TODO(b/193656009): Defining these in a cc file causes linker errors with
129 // fastbuild.
SetAttrShape(const char * attr_name,const PartialTensorShape shape)130 inline Status AbstractOperation::SetAttrShape(const char* attr_name,
131                                               const PartialTensorShape shape) {
132   return SetAttrShape(attr_name, shape.dim_sizes().data(), shape.dims());
133 }
134 
SetAttrStringList(const char * attr_name,absl::Span<string const> values)135 inline Status AbstractOperation::SetAttrStringList(
136     const char* attr_name, absl::Span<string const> values) {
137   std::vector<const char*> raw_strs;
138   std::vector<size_t> lengths;
139   raw_strs.reserve(values.size());
140   lengths.reserve(values.size());
141   for (const auto& s : values) {
142     raw_strs.emplace_back(s.data());
143     lengths.emplace_back(s.size());
144   }
145   return SetAttrStringList(attr_name,
146                            reinterpret_cast<const void**>(raw_strs.data()),
147                            lengths.data(), values.size());
148 }
149 
150 namespace internal {
151 struct AbstractOperationDeleter {
operatorAbstractOperationDeleter152   void operator()(AbstractOperation* p) const {
153     if (p != nullptr) {
154       p->Release();
155     }
156   }
157 };
158 }  // namespace internal
159 
160 using AbstractOperationPtr =
161     std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
162 
163 }  // namespace tensorflow
164 
165 #endif  // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
166