1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
17
18 namespace xla {
19 namespace cpu {
20
Run(int64_t target_partition_count)21 std::vector<int64_t> ShapePartitionAssigner::Run(
22 int64_t target_partition_count) {
23 // Gather outer-most dims where dim_size >= 'target_partition_count'.
24 // This may include the inner-dim as LLVM can vectorize loops with dynamic
25 // bounds.
26 std::vector<int64_t> outer_dims;
27 int64_t outer_dim_size = 1;
28 // TODO(b/27458679) Consider reserving enough minor dimensions (based on
29 // target vector register width) to enable vector instructions.
30 for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) {
31 const int64_t dimension = shape_.layout().minor_to_major(i);
32 outer_dims.push_back(dimension);
33 outer_dim_size *= shape_.dimensions(dimension);
34 if (outer_dim_size >= target_partition_count) {
35 break;
36 }
37 }
38
39 // Clip target partition count if outer dim size is insufficient to cover.
40 target_partition_count = std::min(outer_dim_size, target_partition_count);
41
42 // Calculate the target number of partitions per-dimension, by factoring
43 // 'target_partition_count' into 'num_outer_dims' equal terms.
44 // EX:
45 // *) target_partition_count = 16
46 // *) out_dim_count = 2
47 // *) target_dim_partition_count = 16 ^ (1.0 / 2) == 4
48 const int64_t target_dim_partition_count = std::pow(
49 static_cast<double>(target_partition_count), 1.0 / outer_dims.size());
50
51 // Assign feasible dimension partitions based on 'target_dim_partition_count'
52 // and actual dimension sizes from 'shape_'.
53 std::vector<int64_t> dimension_partition_counts(outer_dims.size());
54 for (int64_t i = 0; i < outer_dims.size(); ++i) {
55 dimension_partition_counts[i] =
56 std::min(static_cast<int64_t>(shape_.dimensions(outer_dims[i])),
57 target_dim_partition_count);
58 }
59
60 // Check if total partition count is below 'target_partition_count'.
61 // This can occur if some dimensions in 'shape_' are below the
62 // 'target_dim_partition_count' threshold.
63 if (GetTotalPartitionCount(dimension_partition_counts) <
64 target_partition_count) {
65 // Assign additional partitions (greedily to outer dimensions), if doing
66 // so would keep the total number of partitions <= 'target_partition_count',
67 // using one pass over 'dimension_partition_counts'.
68 for (int64_t i = 0; i < dimension_partition_counts.size(); ++i) {
69 const int64_t current_dim_partition_count = dimension_partition_counts[i];
70 const int64_t other_dims_partition_count =
71 GetTotalPartitionCount(dimension_partition_counts) /
72 current_dim_partition_count;
73 // Constraint: (current + additional) * other <= target
74 // Calculate: additional = target / other - current
75 int64_t additional_partition_count =
76 target_partition_count / other_dims_partition_count -
77 current_dim_partition_count;
78 // Clip 'additional_partition_count' by current dimension size.
79 additional_partition_count = std::min(
80 shape_.dimensions(outer_dims[i]) - dimension_partition_counts[i],
81 additional_partition_count);
82 if (additional_partition_count > 0) {
83 dimension_partition_counts[i] += additional_partition_count;
84 }
85 }
86 }
87
88 return dimension_partition_counts;
89 }
90
GetTotalPartitionCount(const std::vector<int64_t> & dimension_partition_counts)91 int64_t ShapePartitionAssigner::GetTotalPartitionCount(
92 const std::vector<int64_t>& dimension_partition_counts) {
93 int64_t total_partition_count = 1;
94 for (int64_t dim_partition_count : dimension_partition_counts) {
95 total_partition_count *= dim_partition_count;
96 }
97 return total_partition_count;
98 }
99
ShapePartitionIterator(const Shape & shape,const std::vector<int64_t> & dimension_partition_counts)100 ShapePartitionIterator::ShapePartitionIterator(
101 const Shape& shape, const std::vector<int64_t>& dimension_partition_counts)
102 : shape_(shape),
103 dimension_partition_counts_(dimension_partition_counts),
104 dimensions_(dimension_partition_counts_.size()),
105 dimension_partition_sizes_(dimension_partition_counts_.size()),
106 dimension_partition_strides_(dimension_partition_counts_.size()) {
107 // Store partitioned outer dimensions from 'shape_'.
108 for (int i = 0; i < dimensions_.size(); ++i) {
109 dimensions_[i] = shape_.layout().minor_to_major(
110 shape_.layout().minor_to_major_size() - 1 - i);
111 }
112
113 // Calculate partition size for each dimension (note that the size of
114 // the last partition in each dimension may be different if the dimension
115 // size is not a multiple of partition size).
116 for (int i = 0; i < dimension_partition_sizes_.size(); ++i) {
117 const int64_t dim_size = shape_.dimensions(dimensions_[i]);
118 dimension_partition_sizes_[i] =
119 std::max(int64_t{1}, dim_size / dimension_partition_counts_[i]);
120 }
121
122 // Calculate the partition strides for each dimension.
123 dimension_partition_strides_[dimension_partition_strides_.size() - 1] = 1;
124 for (int i = dimension_partition_strides_.size() - 2; i >= 0; --i) {
125 dimension_partition_strides_[i] = dimension_partition_strides_[i + 1] *
126 dimension_partition_counts_[i + 1];
127 }
128 }
129
GetPartition(int64_t index) const130 std::vector<std::pair<int64_t, int64_t>> ShapePartitionIterator::GetPartition(
131 int64_t index) const {
132 // Calculate and return the partition for 'index'.
133 // Returns for each dimension: (partition_start, partition_size).
134 std::vector<std::pair<int64_t, int64_t>> partition(dimensions_.size());
135 for (int64_t i = 0; i < partition.size(); ++i) {
136 // Calculate the index for dimension 'i'.
137 const int64_t partition_index = index / dimension_partition_strides_[i];
138 // Calculate dimension partition start at 'partition_index'.
139 partition[i].first = partition_index * dimension_partition_sizes_[i];
140 // Calculate dimension partition size (note that the last partition size
141 // may be adjusted if dimension size is not a multiple of partition size).
142 if (partition_index == dimension_partition_counts_[i] - 1) {
143 // Last partition in this dimension.
144 partition[i].second =
145 shape_.dimensions(dimensions_[i]) - partition[i].first;
146 } else {
147 partition[i].second = dimension_partition_sizes_[i];
148 }
149 CHECK_GT(partition[i].second, 0);
150 // Update index to remove contribution from current dimension.
151 index -= partition_index * dimension_partition_strides_[i];
152 }
153 return partition;
154 }
155
GetTotalPartitionCount() const156 int64_t ShapePartitionIterator::GetTotalPartitionCount() const {
157 return ShapePartitionAssigner::GetTotalPartitionCount(
158 dimension_partition_counts_);
159 }
160
161 } // namespace cpu
162 } // namespace xla
163