1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ------------------------------------------------------------------------------*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LAZY_OP_RUNNER_H_ 17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LAZY_OP_RUNNER_H_ 18 19 #include "tensorflow/compiler/xla/stream_executor/dnn.h" 20 #include "tensorflow/compiler/xla/stream_executor/stream.h" 21 22 namespace stream_executor { 23 namespace dnn { 24 25 // A lazily-initialized OpRunner from an AlgorithmDesc. 26 // 27 // This exists to hold a choice of conv algorithm for a particular config, 28 // initialize its OpRunner at most once, and defer that initialization until the 29 // config is first needed. This allows AoT autotuning to load configurations 30 // for all convolutions it knows about, without doing expensive initialization 31 // (e.g. runtime codegen) and retaining non-negligible resources (e.g. compiled 32 // kernels) for potentially irrelevant configurations. It also enables XLA conv 33 // thunks to defer binding to a particular stream executor until the first run. 34 // 35 // `Op` must satisfy the following "concept": 36 // 37 // struct Op { 38 // // The function type signature parameter of an OpRunner. 39 // using Signature = _; 40 // 41 // // The parameter to be used by GetOrCreateRunner. 42 // struct Config; 43 // 44 // // Use a StreamExecutor to create an OpRunner. 45 // static StatusOr<OpRunner<Config>> OpRunnerFromDesc( 46 // const AlgorithmDesc& desc, Config config, StreamExecutor* stream); 47 // }; 48 template <typename Op> 49 class LazyOpRunner { 50 public: 51 // Construct from a pre-initialized OpRunner; all calls to GetOrCreateRunner 52 // will return a pointer to exactly this runner. FromOpRunner(std::unique_ptr<const OpRunner<typename Op::Signature>> runner)53 static port::StatusOr<std::unique_ptr<LazyOpRunner>> FromOpRunner( 54 std::unique_ptr<const OpRunner<typename Op::Signature>> runner) { 55 if (!runner) { 56 return port::InternalError("Null runner argument to FromOpRunner"); 57 } 58 TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); 59 // Private constructor cannot be called by make_unique :( 60 return {std::unique_ptr<LazyOpRunner>( 61 new LazyOpRunner(desc, std::move(runner)))}; 62 } 63 64 // Construct from an AlgorithmDesc, with no pre-initialized OpRunner; it will 65 // be created on the first call to GetOrCreateRunner. LazyOpRunner(AlgorithmDesc desc)66 explicit LazyOpRunner(AlgorithmDesc desc) : LazyOpRunner(desc, nullptr) {} 67 68 // Returns an already-initialized OpRunner if available, or creates one. 69 // 70 // Invariant: a particular instance of this class shall only receive calls 71 // with identical `config`s and `stream_executor`s. If the config is changed, 72 // only the first config it sees will have any effect, and second and 73 // subsequent configs will be ignored. If the stream executor is changed, 74 // some operations on the returned `OpRunner` using the changed stream 75 // executor will be errors. 76 // 77 // The result is owned by LazyOpRunner. GetOrCreateRunner(typename Op::Config config,Stream * stream)78 port::StatusOr<const OpRunner<typename Op::Signature>*> GetOrCreateRunner( 79 typename Op::Config config, Stream* stream) { 80 absl::MutexLock lock(&mu_); 81 if (!runner_) { 82 TF_ASSIGN_OR_RETURN(runner_, Op::RunnerFromAlgorithmDesc( 83 desc_, std::move(config), stream)); 84 } 85 return runner_.get(); 86 } 87 88 // Get the contained runner with the invariant that it's already initialized. GetRunner()89 port::StatusOr<const OpRunner<typename Op::Signature>*> GetRunner() { 90 absl::MutexLock lock(&mu_); 91 if (!runner_) { 92 return port::InternalError("LazyOpRunner::GetRunner: not initialized"); 93 } 94 return runner_.get(); 95 } 96 97 bool operator==(const LazyOpRunner& other) const { 98 return desc_ == other.desc_; 99 } 100 ToString()101 std::string ToString() const { return desc_.ToString(); } 102 ToAlgorithmDesc()103 const AlgorithmDesc& ToAlgorithmDesc() const { return desc_; } 104 105 private: LazyOpRunner(AlgorithmDesc desc,std::unique_ptr<const OpRunner<typename Op::Signature>> runner)106 LazyOpRunner(AlgorithmDesc desc, 107 std::unique_ptr<const OpRunner<typename Op::Signature>> runner) 108 : desc_(std::move(desc)), runner_(std::move(runner)) {} 109 110 AlgorithmDesc desc_; 111 absl::Mutex mu_; 112 std::unique_ptr<const OpRunner<typename Op::Signature>> runner_ 113 ABSL_GUARDED_BY(mu_); 114 }; 115 116 // Implementation of the concept required by LazyOpRunner, for ConvRunner. 117 struct ConvOp { 118 using Signature = ConvSignature; 119 120 struct Config { 121 ConvolutionKind kind; 122 DataType input_type, output_type; 123 const BatchDescriptor& input_descriptor; 124 const FilterDescriptor& filter_descriptor; 125 const BatchDescriptor& output_descriptor; 126 const ConvolutionDescriptor& convolution_descriptor; 127 }; 128 129 static port::StatusOr<std::unique_ptr<const OpRunner<ConvSignature>>> RunnerFromAlgorithmDescConvOp130 RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, 131 Stream* stream) { 132 return stream->ConvolveRunnerFromDesc( 133 desc, config.kind, config.input_type, config.output_type, 134 config.input_descriptor, config.filter_descriptor, 135 config.output_descriptor, config.convolution_descriptor); 136 } 137 }; 138 139 // Implementation of the concept required by LazyOpRunner, for LazyConvRunner. 140 struct FusedConvOp { 141 using Signature = FusedConvSignature; 142 143 struct Config { 144 ConvolutionKind kind; 145 DataType input_type, bias_type, output_type; 146 double conv_scale, side_input_scale, leakyrelu_alpha; 147 const BatchDescriptor& input_descriptor; 148 const FilterDescriptor& filter_descriptor; 149 const BatchDescriptor& bias_descriptor; 150 const BatchDescriptor& output_descriptor; 151 const ConvolutionDescriptor& convolution_descriptor; 152 ActivationMode activation_mode; 153 }; 154 155 static port::StatusOr<std::unique_ptr<const OpRunner<FusedConvSignature>>> RunnerFromAlgorithmDescFusedConvOp156 RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, 157 Stream* stream) { 158 return stream->FusedConvolveRunnerFromDesc( 159 desc, config.kind, config.input_type, config.bias_type, 160 config.output_type, config.conv_scale, config.side_input_scale, 161 config.leakyrelu_alpha, config.input_descriptor, 162 config.filter_descriptor, config.bias_descriptor, 163 config.output_descriptor, config.convolution_descriptor, 164 config.activation_mode); 165 } 166 }; 167 168 } // namespace dnn 169 } // namespace stream_executor 170 171 #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_LAZY_OP_RUNNER_H_ 172