1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_ 16 #define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_ 17 18 #include "tensorflow/lite/c/common.h" 19 #include "tensorflow/lite/delegates/flex/delegate_data.h" 20 #include "tensorflow/lite/delegates/utils/simple_delegate.h" 21 22 namespace tflite { 23 24 namespace flex { 25 namespace testing { 26 class KernelTest; 27 } // namespace testing 28 } // namespace flex 29 30 // WARNING: This is an experimental interface that is subject to change. 31 // Delegate that can be used to extract parts of a graph that are designed to be 32 // executed by TensorFlow's runtime via Eager. 33 // 34 // The interpreter must be constructed after the FlexDelegate and destructed 35 // before the FlexDelegate. This delegate may be used with multiple 36 // interpreters, but it is *not* thread-safe. 37 // 38 // Usage: 39 // auto delegate = FlexDelegate::Create(); 40 // ... build interpreter ... 41 // 42 // if (delegate) { 43 // interpreter->ModifyGraphWithDelegate(delegate.get()); 44 // } 45 // 46 // void* delegate_data = delegate->data_; 47 // interpreter->SetCancellationFunction( 48 // delegate_data, 49 // FlexDelegate::HasCancelled); 50 // 51 // ... run inference ... 52 // 53 // static_cast<FlexDelegate*>(delegate_data)->Cancel(); 54 // 55 // ... destroy interpreter ... 56 // ... destroy delegate ... 57 class FlexDelegate : public SimpleDelegateInterface { 58 public: 59 friend class flex::testing::KernelTest; 60 61 // Creates a delegate that supports TF ops. Create()62 static TfLiteDelegateUniquePtr Create() { 63 return Create(/*base_delegate*/ nullptr); 64 } 65 ~FlexDelegate()66 ~FlexDelegate() override {} 67 mutable_data()68 flex::DelegateData* mutable_data() { return &delegate_data_; } 69 70 // This method is thread safe. It does two things: 71 // 1. Calls the CancellationManager of the TF eager runtime to support 72 // intra-op cancellation in TF. 73 // 2. Uses the CancellationManager to signal TFLite interpreter for inter-op 74 // cancellation. 75 // Training is non-recoverable after calling this API. 76 void Cancel(); 77 78 // The param `data` must be a pointer to a FlexDelegate instance. 79 static bool HasCancelled(void* data); 80 81 protected: 82 // We sometimes have to create certain stub data to test FlexDelegate. To 83 // achieve this, we will make a testing flex delegate class that inherits from 84 // FlexDelegate to override certain things for stub data creation. Therefore, 85 // this function accepts a FlexDelegate instance to initiliaze it properly for 86 // create a testing flex delegate in some cases, and it is only used in 87 // testing. 88 static TfLiteDelegateUniquePtr Create( 89 std::unique_ptr<FlexDelegate> base_delegate); 90 FlexDelegate()91 FlexDelegate() {} 92 93 const char* Name() const override; 94 95 bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration, 96 const TfLiteNode* node, 97 TfLiteContext* context) const override; 98 99 TfLiteStatus Initialize(TfLiteContext* context) override; 100 DelegateOptions()101 SimpleDelegateInterface::Options DelegateOptions() const override { 102 // Use default options. 103 return SimpleDelegateInterface::Options(); 104 } 105 106 std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface() 107 override; 108 109 TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, 110 TfLiteBufferHandle buffer_handle, 111 TfLiteTensor* output); 112 113 flex::DelegateData delegate_data_; 114 115 // Pointer to the base TfLiteDelegate which is created from the Create call. 116 TfLiteDelegate* base_delegate_ = nullptr; 117 118 private: 119 // A cancellation manager. 120 std::unique_ptr<tensorflow::CancellationManager> cancellation_manager_; 121 }; 122 123 } // namespace tflite 124 125 #endif // TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_ 126