xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/shape_partition.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
17 
18 namespace xla {
19 namespace cpu {
20 
Run(int64_t target_partition_count)21 std::vector<int64_t> ShapePartitionAssigner::Run(
22     int64_t target_partition_count) {
23   // Gather outer-most dims where dim_size >= 'target_partition_count'.
24   // This may include the inner-dim as LLVM can vectorize loops with dynamic
25   // bounds.
26   std::vector<int64_t> outer_dims;
27   int64_t outer_dim_size = 1;
28   // TODO(b/27458679) Consider reserving enough minor dimensions (based on
29   // target vector register width) to enable vector instructions.
30   for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) {
31     const int64_t dimension = shape_.layout().minor_to_major(i);
32     outer_dims.push_back(dimension);
33     outer_dim_size *= shape_.dimensions(dimension);
34     if (outer_dim_size >= target_partition_count) {
35       break;
36     }
37   }
38 
39   // Clip target partition count if outer dim size is insufficient to cover.
40   target_partition_count = std::min(outer_dim_size, target_partition_count);
41 
42   // Calculate the target number of partitions per-dimension, by factoring
43   // 'target_partition_count' into 'num_outer_dims' equal terms.
44   // EX:
45   // *) target_partition_count = 16
46   // *) out_dim_count = 2
47   // *) target_dim_partition_count = 16 ^ (1.0 / 2) == 4
48   const int64_t target_dim_partition_count = std::pow(
49       static_cast<double>(target_partition_count), 1.0 / outer_dims.size());
50 
51   // Assign feasible dimension partitions based on 'target_dim_partition_count'
52   // and actual dimension sizes from 'shape_'.
53   std::vector<int64_t> dimension_partition_counts(outer_dims.size());
54   for (int64_t i = 0; i < outer_dims.size(); ++i) {
55     dimension_partition_counts[i] =
56         std::min(static_cast<int64_t>(shape_.dimensions(outer_dims[i])),
57                  target_dim_partition_count);
58   }
59 
60   // Check if total partition count is below 'target_partition_count'.
61   // This can occur if some dimensions in 'shape_' are below the
62   // 'target_dim_partition_count' threshold.
63   if (GetTotalPartitionCount(dimension_partition_counts) <
64       target_partition_count) {
65     // Assign additional partitions (greedily to outer dimensions), if doing
66     // so would keep the total number of partitions <= 'target_partition_count',
67     // using one pass over 'dimension_partition_counts'.
68     for (int64_t i = 0; i < dimension_partition_counts.size(); ++i) {
69       const int64_t current_dim_partition_count = dimension_partition_counts[i];
70       const int64_t other_dims_partition_count =
71           GetTotalPartitionCount(dimension_partition_counts) /
72           current_dim_partition_count;
73       // Constraint: (current + additional) * other <= target
74       // Calculate: additional = target / other - current
75       int64_t additional_partition_count =
76           target_partition_count / other_dims_partition_count -
77           current_dim_partition_count;
78       // Clip 'additional_partition_count' by current dimension size.
79       additional_partition_count = std::min(
80           shape_.dimensions(outer_dims[i]) - dimension_partition_counts[i],
81           additional_partition_count);
82       if (additional_partition_count > 0) {
83         dimension_partition_counts[i] += additional_partition_count;
84       }
85     }
86   }
87 
88   return dimension_partition_counts;
89 }
90 
GetTotalPartitionCount(const std::vector<int64_t> & dimension_partition_counts)91 int64_t ShapePartitionAssigner::GetTotalPartitionCount(
92     const std::vector<int64_t>& dimension_partition_counts) {
93   int64_t total_partition_count = 1;
94   for (int64_t dim_partition_count : dimension_partition_counts) {
95     total_partition_count *= dim_partition_count;
96   }
97   return total_partition_count;
98 }
99 
ShapePartitionIterator(const Shape & shape,const std::vector<int64_t> & dimension_partition_counts)100 ShapePartitionIterator::ShapePartitionIterator(
101     const Shape& shape, const std::vector<int64_t>& dimension_partition_counts)
102     : shape_(shape),
103       dimension_partition_counts_(dimension_partition_counts),
104       dimensions_(dimension_partition_counts_.size()),
105       dimension_partition_sizes_(dimension_partition_counts_.size()),
106       dimension_partition_strides_(dimension_partition_counts_.size()) {
107   // Store partitioned outer dimensions from 'shape_'.
108   for (int i = 0; i < dimensions_.size(); ++i) {
109     dimensions_[i] = shape_.layout().minor_to_major(
110         shape_.layout().minor_to_major_size() - 1 - i);
111   }
112 
113   // Calculate partition size for each dimension (note that the size of
114   // the last partition in each dimension may be different if the dimension
115   // size is not a multiple of partition size).
116   for (int i = 0; i < dimension_partition_sizes_.size(); ++i) {
117     const int64_t dim_size = shape_.dimensions(dimensions_[i]);
118     dimension_partition_sizes_[i] =
119         std::max(int64_t{1}, dim_size / dimension_partition_counts_[i]);
120   }
121 
122   // Calculate the partition strides for each dimension.
123   dimension_partition_strides_[dimension_partition_strides_.size() - 1] = 1;
124   for (int i = dimension_partition_strides_.size() - 2; i >= 0; --i) {
125     dimension_partition_strides_[i] = dimension_partition_strides_[i + 1] *
126                                       dimension_partition_counts_[i + 1];
127   }
128 }
129 
GetPartition(int64_t index) const130 std::vector<std::pair<int64_t, int64_t>> ShapePartitionIterator::GetPartition(
131     int64_t index) const {
132   // Calculate and return the partition for 'index'.
133   // Returns for each dimension: (partition_start, partition_size).
134   std::vector<std::pair<int64_t, int64_t>> partition(dimensions_.size());
135   for (int64_t i = 0; i < partition.size(); ++i) {
136     // Calculate the index for dimension 'i'.
137     const int64_t partition_index = index / dimension_partition_strides_[i];
138     // Calculate dimension partition start at 'partition_index'.
139     partition[i].first = partition_index * dimension_partition_sizes_[i];
140     // Calculate dimension partition size (note that the last partition size
141     // may be adjusted if dimension size is not a multiple of partition size).
142     if (partition_index == dimension_partition_counts_[i] - 1) {
143       // Last partition in this dimension.
144       partition[i].second =
145           shape_.dimensions(dimensions_[i]) - partition[i].first;
146     } else {
147       partition[i].second = dimension_partition_sizes_[i];
148     }
149     CHECK_GT(partition[i].second, 0);
150     // Update index to remove contribution from current dimension.
151     index -= partition_index * dimension_partition_strides_[i];
152   }
153   return partition;
154 }
155 
GetTotalPartitionCount() const156 int64_t ShapePartitionIterator::GetTotalPartitionCount() const {
157   return ShapePartitionAssigner::GetTotalPartitionCount(
158       dimension_partition_counts_);
159 }
160 
161 }  // namespace cpu
162 }  // namespace xla
163