1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ 18 19 #include <memory> 20 #include <optional> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 25 namespace xla { 26 namespace dot_as_convolution_util { 27 28 // Type of Batch representation for a convolution that has a spatial dimension 29 // that is effectively a batch dimension. We currently have two 30 // representations that we detect as "batch equivalent" and this enum allows 31 // differentiating between the two. 32 enum class SpatialBatchRepresentation { 33 kNone, 34 kUnpaddedVersion, 35 kPaddedVersion, 36 }; 37 38 // Describes the dimensions of a convolution that can be interpreted as a dot 39 // or a normal convolution. 40 struct DotConvolutionDimsInfo { 41 // The dimension numbers for the operands and output corresponding to a 42 // logical dimension (e.g., batch, contracting, non-contracting). If an 43 // operand or the output doesn't have the logical dimension, it is set to 44 // -1. 45 struct DimNums { 46 int64_t lhs; 47 int64_t rhs; 48 int64_t output; 49 // The corresponding spatial dimension in the convolution's config. Set to 50 // -1 if it's not mapped to a spatial dimension. 51 int64_t spatial_dim; 52 }; 53 std::vector<DimNums> batch_dims; 54 std::vector<DimNums> contracting_dims; 55 std::vector<DimNums> lhs_non_contracting_dims; 56 std::vector<DimNums> rhs_non_contracting_dims; 57 std::vector<DimNums> conv_spatial_dims; 58 }; 59 60 // Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo. If it can 61 // be interpreted as a dot, there is no conv_spatial_dims. 62 DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv); 63 64 // Creates sharded convolution instruction that can be interpreted as a dot. 65 // This is a utility for per-op partitioners. 66 // - 'conv' is the original convolution instruction. 67 // - 'dot_dnums' is the result of ParseDotConvolutionDimsInfo() for 'conv'. 68 // - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result 69 // convolution instruction. 70 StatusOr<std::unique_ptr<HloInstruction>> 71 CreateShardedConvForDotGeneralConvolution( 72 const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, 73 HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); 74 75 // Check if a spatial dim is parallel batch dimension. 76 // A parallel batch dimension in DotGeneral is represented as a spatial 77 // dimension with window size B (batch dimension size), stride B - 1, and base 78 // dilation B or an alternative representation of window size B, stride B, 79 // padding low/high B - 1, base dilation B - 1 and window reversal 80 SpatialBatchRepresentation SpatialIsBatch(int64_t lhs_spatial_size, 81 const WindowDimension& spatial_wd); 82 // Returns if the spatial dimension represented by 'spatial_wd' is an LHS non 83 // contracting dimension. 84 bool SpatialIsLhsNonContracting(int64_t rhs_spatial_size, 85 const WindowDimension& spatial_wd); 86 // Returns if the spatial dimension represented by 'spatial_wd' is an RHS non 87 // contracting dimension. 88 bool SpatialIsRhsNonContracting(int64_t lhs_spatial_size, 89 int64_t rhs_spatial_size, 90 const WindowDimension& spatial_wd); 91 // Returns if the spatial dimension represented by 'spatial_wd' endsup being 92 // equivalent to a contracting dimension. 93 bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size, 94 const WindowDimension& spatial_wd); 95 // Returns a DotConvolutionDimsInfo from a kDot instruction, where all 96 // the spatial_dim values are set to -1. 97 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot); 98 99 } // namespace dot_as_convolution_util 100 } // namespace xla 101 102 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ 103