xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/autotune_maps/conv_parameters.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_PARAMETERS_H_
17 #define TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_PARAMETERS_H_
18 
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 #include "absl/types/optional.h"
21 #include "tensorflow/core/platform/stream_executor.h"
22 #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h"
23 
24 namespace tensorflow {
25 // Uniquely identifies a convolution operation that runs on a particular device
26 // model.
27 //
28 // This can serve as a hashtable key, where the value might be the autotuned
29 // algorithm we choose for the conv.
30 //
31 // All of the data in this class other than the device_id is stored in the
32 // ConvParametersProto, so it can be easily serialized (for the purposes of
33 // ahead-of-time autotuning).
34 //
35 // When using the cudnn frontend API, two autotuning results for two different
36 // GPUs of the same model are not interchangeable, because an autotuning result
37 // includes a cudnn execution plan, which is tied to the GPU.  As a result, we
38 // need to create separate ConvParameters objects for them.
39 class ConvParameters {
40  public:
41   struct FusionInfo {
42     // For some implementations (e.g. cuDNN new backend) these scales are part
43     // of the algorithm, not part of the parameters an algorithm take. They need
44     // to be used to distinguish different algorithms.
45     double conv_scale;
46     double side_input_scale;
47     double leakyrelu_alpha;
48     stream_executor::dnn::ActivationMode activation_mode;
49     bool is_contrib;
50   };
51 
52   // LINT.IfChange(conv_parameters_version)
53   // A positive number that denotes the version of this class. Should be
54   // incremented everytime this class or ConvParametersProto are updated in a
55   // way that may invalidate autotune results.
56   static constexpr int kVersion = 2;
57   // LINT.ThenChange()
58 
59   // We have three kinds of convolutions today.  Vanilla unfused convolutions,
60   // fused convolutions, and fused convolutions as implemented in the `contrib`
61   // directory.  The two fused convolutions ultimately correspond to the same
62   // cudnn calls, but have slightly different semantics (e.g. they interpret
63   // padding differently).
64   ConvParameters(
65       int64_t batch, int64_t in_depths, absl::Span<const int64_t> in,
66       int data_format, int64_t out_depths, absl::Span<const int64_t> filter,
67       absl::Span<const int64_t> dilation, absl::Span<const int64_t> stride,
68       absl::Span<const int64_t> padding, DataType dtype, int device_id,
69       int group_count,
70       absl::optional<FusionInfo> fusion_info = absl::optional<FusionInfo>(),
71       // This argument should be set only for test use.
72       int version = kVersion);
73 
74   ConvParameters(int device_id, const ConvParametersProto& proto);
75 
76   bool operator==(const ConvParameters& other) const;
77 
78   bool operator!=(const ConvParameters& other) const {
79     return !(*this == other);
80   }
hash()81   uint64 hash() const { return hash_code_; }
82 
83   string ToString() const;
84 
proto()85   const ConvParametersProto& proto() const { return proto_; }
86 
87  private:
88   int device_id_;
89   ConvParametersProto proto_;
90   uint64 hash_code_;
91 };
92 }  // namespace tensorflow
93 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
94 
95 #endif  // TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_PARAMETERS_H_
96