xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/layout_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/layout_util.h"
17 
18 #include <stddef.h>
19 
20 #include <algorithm>
21 #include <functional>
22 #include <numeric>
23 #include <random>
24 #include <string>
25 #include <vector>
26 
27 #include "absl/hash/hash.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 namespace xla {
39 namespace {
40 
// Internal helper shared by the default-layout builders: overwrites every
// entry of `minor_to_major` so that it reads {n - 1, ..., 1, 0}, where n is the
// container's current size. That ordering is XLA's default layout, in which
// dimension 0 is the most major. See
// https://www.tensorflow.org/performance/xla/shapes for background.
template <typename T>
void SetDefaultLayoutToContainer(T* minor_to_major) {
  const int64_t count = minor_to_major->size();
  int64_t next_dim = count;
  for (int64_t pos = 0; pos < count; ++pos) {
    // Position 0 (most minor) receives the highest dimension number.
    (*minor_to_major)[pos] = --next_dim;
  }
}
53 
54 }  // namespace
55 
MakeLayout(absl::Span<const int64_t> minor_to_major,absl::Span<const DimLevelType> dim_level_types,absl::Span<const Tile> tiles,int64_t element_size_in_bits,int64_t memory_space)56 /* static */ Layout LayoutUtil::MakeLayout(
57     absl::Span<const int64_t> minor_to_major,
58     absl::Span<const DimLevelType> dim_level_types,
59     absl::Span<const Tile> tiles, int64_t element_size_in_bits,
60     int64_t memory_space) {
61   Layout layout;
62   for (int64_t dimension_number : minor_to_major) {
63     layout.add_minor_to_major(dimension_number);
64   }
65   for (DimLevelType dim_level_type : dim_level_types) {
66     layout.add_dim_level_type(dim_level_type);
67   }
68   for (const Tile& tile : tiles) {
69     for (int64_t dim : tile.dimensions()) {
70       if (dim < 0 && dim != Tile::kCombineDimension) {
71         LOG(FATAL)
72             << "Tile dimension size needs to be minimum int64_t value if "
73                "it's negative. Value is "
74             << dim;
75       }
76     }
77     *layout.add_tiles() = tile;
78   }
79   layout.set_element_size_in_bits(element_size_in_bits);
80   layout.set_memory_space(memory_space);
81   return layout;
82 }
83 
MakeDescendingLayout(int64_t rank)84 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64_t rank) {
85   std::vector<int64_t> layout(rank);
86   std::iota(layout.rbegin(), layout.rend(), static_cast<int64_t>(0));
87   return MakeLayout(layout);
88 }
89 
MakeAscendingLayout(int64_t rank)90 /* static */ Layout LayoutUtil::MakeAscendingLayout(int64_t rank) {
91   std::vector<int64_t> layout(rank);
92   std::iota(layout.begin(), layout.end(), static_cast<int64_t>(0));
93   return MakeLayout(layout);
94 }
95 
MakeLayoutFromMajorToMinor(absl::Span<const int64_t> major_to_minor)96 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
97     absl::Span<const int64_t> major_to_minor) {
98   Layout layout;
99   for (int i = major_to_minor.size() - 1; i >= 0; i--) {
100     layout.add_minor_to_major(major_to_minor[i]);
101   }
102   return layout;
103 }
104 
105 namespace {
106 
107 // Internal helper that creates a default layout for an array of the given rank.
CreateDefaultLayoutForRank(int64_t rank)108 Layout CreateDefaultLayoutForRank(int64_t rank) {
109   Layout layout;
110   auto* minor_to_major = layout.mutable_minor_to_major();
111   minor_to_major->resize(rank, 0);
112   SetDefaultLayoutToContainer(minor_to_major);
113   return layout;
114 }
115 
116 }  // namespace
117 
GetDefaultLayoutForShape(const Shape & shape)118 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
119   if (shape.IsOpaque() || shape.IsToken()) {
120     // Opaque and token types have empty layouts.
121     return Layout();
122   }
123 
124   // A Layout proto corresponds to a single array, not a tuple.
125   CHECK(shape.IsArray());
126   return CreateDefaultLayoutForRank(shape.dimensions_size());
127 }
128 
GetDefaultLayoutForRank(int64_t rank)129 /* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64_t rank) {
130   return CreateDefaultLayoutForRank(rank);
131 }
132 
GetDefaultLayoutForR2()133 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
134   return CreateDefaultLayoutForRank(2);
135 }
136 
GetDefaultLayoutForR3()137 /* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
138   return CreateDefaultLayoutForRank(3);
139 }
140 
GetDefaultLayoutForR4()141 /* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
142   return CreateDefaultLayoutForRank(4);
143 }
144 
SetToDefaultLayout(Shape * shape)145 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
146   if (shape->IsTuple()) {
147     // Tuple shape.
148     for (auto& element_shape : *shape->mutable_tuple_shapes()) {
149       SetToDefaultLayout(&element_shape);
150     }
151     shape->clear_layout();
152   } else if (shape->IsArray()) {
153     auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
154     minor_to_major->resize(shape->dimensions_size(), 0);
155     SetDefaultLayoutToContainer(minor_to_major);
156   } else {
157     // Opaque, token types etc. have no layout.
158     shape->clear_layout();
159   }
160 }
161 
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  // Returns a copy of `shape` with SetToDefaultLayout applied; the argument
  // itself is not modified.
  Shape result = shape;
  SetToDefaultLayout(&result);
  return result;
}
167 
SetToDefaultLayout(ProgramShape * program_shape)168 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
169   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
170     LayoutUtil::SetToDefaultLayout(&parameter_shape);
171   }
172   LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
173 }
174 
ValidateLayoutInShape(const Shape & shape,bool allow_missing_layouts)175 /* static */ Status LayoutUtil::ValidateLayoutInShape(
176     const Shape& shape, bool allow_missing_layouts) {
177   if (shape.IsTuple()) {
178     // Tuple shape.
179     if (shape.has_layout()) {
180       return InvalidArgument("tuple should not have a layout field");
181     }
182     for (auto& element_shape : shape.tuple_shapes()) {
183       TF_RETURN_IF_ERROR(
184           ValidateLayoutInShape(element_shape, allow_missing_layouts));
185     }
186     return OkStatus();
187   } else if (shape.IsArray()) {
188     if (!shape.has_layout()) {
189       if (allow_missing_layouts) {
190         return OkStatus();
191       }
192       return InvalidArgument("shape %s does not have a layout",
193                              ShapeUtil::HumanString(shape));
194     }
195     return ValidateLayoutForShape(shape.layout(), shape);
196   } else {
197     // Token, opaque, etc. shape.
198     if (shape.has_layout()) {
199       return InvalidArgument(
200           "shape of primitive type %s should not have a layout",
201           PrimitiveType_Name(shape.element_type()));
202     }
203     return OkStatus();
204   }
205 }
206 
ValidateLayoutForShape(const Layout & layout,const Shape & shape)207 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
208                                                        const Shape& shape) {
209   if (shape.IsTuple()) {
210     return InvalidArgument("a single Layout is not valid for tuple shapes");
211   }
212 
213   if (!shape.IsArray()) {
214     if (layout.minor_to_major_size() != 0) {
215       return InvalidArgument(
216           "shape of primitive type %s should not have a non-trivial layout",
217           PrimitiveType_Name(shape.element_type()));
218     }
219     return OkStatus();
220   }
221 
222   if (layout.minor_to_major_size() != shape.rank()) {
223     return InvalidArgument(
224         "layout minor_to_major field contains %d elements, "
225         "but shape is rank %d: {%s}; shape: %s",
226         layout.minor_to_major_size(), shape.rank(),
227         absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString());
228   }
229 
230   std::vector<bool> dimensions_in_layout(shape.rank(), false);
231   for (int64_t i = 0; i < shape.rank(); ++i) {
232     int64_t dim = layout.minor_to_major(i);
233     if (dim < 0 || dim >= shape.rank()) {
234       return InvalidArgument(
235           "layout minor_to_major field has out-of-bounds value: %s",
236           HumanString(layout));
237     }
238     if (dimensions_in_layout[dim]) {
239       return InvalidArgument(
240           "layout minor_to_major field has duplicate values: {%s}",
241           HumanString(layout));
242     }
243     dimensions_in_layout[dim] = true;
244   }
245 
246   if (!layout.dim_level_types().empty()) {
247     if (layout.dim_level_types().size() != shape.rank()) {
248       return InvalidArgument(
249           "layout dim_level_types field contains %d elements, but shape is "
250           "rank %d: {%s}; shape: %s",
251           layout.dim_level_types_size(), shape.rank(),
252           absl::StrJoin(layout.dim_level_types(), ", ",
253                         [](std::string* out, DimLevelType dim_level_type) {
254                           absl::StrAppend(out,
255                                           DimLevelType_Name(dim_level_type));
256                         }),
257           shape.ShortDebugString());
258     }
259     if (LayoutUtil::IsSparse(layout)) {
260       if (layout.tiles_size() > 0) {
261         return InvalidArgument(
262             "layout has tiles, but the shape is a sparse array: %s",
263             shape.ShortDebugString());
264       }
265     }
266   }
267 
268   return OkStatus();
269 }
270 
ClearLayout(Shape * shape)271 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
272   shape->clear_layout();
273   for (auto& element_shape : *shape->mutable_tuple_shapes()) {
274     ClearLayout(&element_shape);
275   }
276 }
277 
ClearLayout(ProgramShape * program_shape)278 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
279   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
280     LayoutUtil::ClearLayout(&parameter_shape);
281   }
282   LayoutUtil::ClearLayout(program_shape->mutable_result());
283 }
284 
ClearTiles(Shape * shape)285 /* static */ void LayoutUtil::ClearTiles(Shape* shape) {
286   ShapeUtil::ForEachMutableSubshape(
287       shape, [](Shape* subshape, const ShapeIndex&) {
288         if (subshape->has_layout()) {
289           if (subshape->has_layout()) {
290             subshape->mutable_layout()->clear_tiles();
291           }
292         }
293       });
294 }
295 
IsDenseArray(const Shape & shape)296 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
297   return shape.IsArray() && (!shape.has_layout() || IsDense(shape.layout()));
298 }
299 
IsSparseArray(const Shape & shape)300 /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
301   return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout());
302 }
303 
IsCOOArray(const Shape & shape)304 /* static */ bool LayoutUtil::IsCOOArray(const Shape& shape) {
305   return shape.IsArray() && shape.has_layout() && IsCOO(shape.layout());
306 }
307 
IsCSRArray(const Shape & shape)308 /* static */ bool LayoutUtil::IsCSRArray(const Shape& shape) {
309   return shape.IsArray() && shape.rank() == 2 && shape.has_layout() &&
310          IsCSR(shape.layout());
311 }
312 
IsCSCArray(const Shape & shape)313 /* static */ bool LayoutUtil::IsCSCArray(const Shape& shape) {
314   return shape.IsArray() && shape.rank() == 2 && shape.has_layout() &&
315          IsCSC(shape.layout());
316 }
317 
IsDense(const Layout & layout)318 /* static */ bool LayoutUtil::IsDense(const Layout& layout) {
319   return absl::c_all_of(
320       layout.dim_level_types(),
321       [](DimLevelType dim_level_type) { return dim_level_type == DIM_DENSE; });
322 }
323 
IsSparse(const Layout & layout)324 /* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
325   return !IsDense(layout);
326 }
327 
IsCOO(const Layout & layout)328 /* static */ bool LayoutUtil::IsCOO(const Layout& layout) {
329   return !layout.dim_level_types().empty() &&
330          layout.dim_level_type(0) == DIM_COMPRESSED &&
331          absl::c_all_of(layout.dim_level_types().subspan(1),
332                         [](DimLevelType dim_level_type) {
333                           return dim_level_type == DIM_SINGLETON;
334                         });
335 }
336 
IsCSR(const Layout & layout)337 /* static */ bool LayoutUtil::IsCSR(const Layout& layout) {
338   return IsMonotonicWithDim0Major(layout) &&
339          layout.dim_level_types() ==
340              absl::Span<const DimLevelType>{DIM_DENSE, DIM_COMPRESSED};
341 }
342 
IsCSC(const Layout & layout)343 /* static */ bool LayoutUtil::IsCSC(const Layout& layout) {
344   return IsMonotonicWithDim0Minor(layout) &&
345          layout.dim_level_types() ==
346              absl::Span<const DimLevelType>{DIM_DENSE, DIM_COMPRESSED};
347 }
348 
IsMonotonicWithDim0Minor(const Layout & layout)349 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
350   return std::is_sorted(layout.minor_to_major().begin(),
351                         layout.minor_to_major().end());
352 }
353 
IsMonotonicWithDim0Major(const Layout & layout)354 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
355   return std::is_sorted(layout.minor_to_major().begin(),
356                         layout.minor_to_major().end(), std::greater<int64_t>());
357 }
358 
HasLayout(const Shape & shape)359 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
360   if (shape.IsTuple()) {
361     // Tuple shape: all subshapes must have a layout.
362     return absl::c_all_of(shape.tuple_shapes(),
363                           [](const Shape& s) { return HasLayout(s); });
364   } else if (!shape.IsArray()) {
365     // Opaque, token types etc. ignore layout.
366     return true;
367   }
368   return shape.has_layout();
369 }
370 
HasLayout(const ProgramShape & program_shape)371 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
372   for (auto& parameter_shape : program_shape.parameters()) {
373     if (!LayoutUtil::HasLayout(parameter_shape)) {
374       return false;
375     }
376   }
377   return LayoutUtil::HasLayout(program_shape.result());
378 }
379 
Equal(const Layout & lhs,const Layout & rhs)380 /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
381   return lhs == rhs;
382 }
383 
MinorToMajor(const Shape & shape)384 /* static */ absl::Span<const int64_t> LayoutUtil::MinorToMajor(
385     const Shape& shape) {
386   CHECK(shape.IsArray());
387   return shape.layout().minor_to_major();
388 }
389 
MinorToMajor(const Layout & layout)390 /* static */ absl::Span<const int64_t> LayoutUtil::MinorToMajor(
391     const Layout& layout) {
392   return layout.minor_to_major();
393 }
394 
Major(const Layout & layout,int64_t physical_dimension_number)395 /* static */ int64_t LayoutUtil::Major(const Layout& layout,
396                                        int64_t physical_dimension_number) {
397   CHECK_LE(0, physical_dimension_number);
398   CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
399   return Minor(layout,
400                layout.minor_to_major_size() - 1 - physical_dimension_number);
401 }
402 
Minor(const Layout & layout,int64_t physical_dimension_number)403 /* static */ int64_t LayoutUtil::Minor(const Layout& layout,
404                                        int64_t physical_dimension_number) {
405   CHECK_LE(0, physical_dimension_number);
406   CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
407   return layout.minor_to_major(physical_dimension_number);
408 }
409 
MakeLogicalToPhysical(const Layout & layout)410 /* static */ std::vector<int64_t> LayoutUtil::MakeLogicalToPhysical(
411     const Layout& layout) {
412   std::vector<int64_t> logical_to_physical(layout.minor_to_major_size());
413   for (int64_t physical = 0, end = logical_to_physical.size(); physical < end;
414        ++physical) {
415     const int64_t logical = Major(layout, physical);
416     logical_to_physical[logical] = physical;
417   }
418   return logical_to_physical;
419 }
420 
HumanString(const Layout & layout)421 /* static */ std::string LayoutUtil::HumanString(const Layout& layout) {
422   return layout.ToString();
423 }
424 
425 namespace {
426 
427 // Internal helper for recursively copying layouts.
CopyLayoutInternal(const Shape & src,Shape * dst)428 Status CopyLayoutInternal(const Shape& src, Shape* dst) {
429   if (src.IsTuple() != dst->IsTuple()) {
430     return InvalidArgument(
431         "cannot copy layout from shape: shape structure differs");
432   }
433   if (src.IsTuple()) {
434     if (ShapeUtil::TupleElementCount(src) !=
435         ShapeUtil::TupleElementCount(*dst)) {
436       return InvalidArgument(
437           "cannot copy layout from shape: tuple element count differs");
438     }
439     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
440       TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
441                                             dst->mutable_tuple_shapes(i)));
442     }
443   } else {
444     if (src.has_layout()) {
445       if (src.rank() != dst->rank()) {
446         return InvalidArgument("cannot copy layout from shape: ranks differs");
447       }
448       TF_RETURN_IF_ERROR(
449           LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
450       *dst->mutable_layout() = src.layout();
451     } else {
452       dst->clear_layout();
453     }
454   }
455   return OkStatus();
456 }
457 
458 }  // namespace
459 
460 /* static */
CopyLayoutBetweenShapes(const Shape & src,Shape * dst)461 Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
462   return CopyLayoutInternal(src, dst);
463 }
464 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)465 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
466                                                    const Shape& rhs) {
467   if (lhs.IsTuple()) {
468     if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
469                               ShapeUtil::TupleElementCount(rhs)) {
470       return false;
471     }
472     for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
473       if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
474         return false;
475       }
476     }
477     return true;
478   }
479   if (lhs.IsArray()) {
480     if (lhs.rank() != rhs.rank()) {
481       return false;
482     }
483     if (!lhs.has_layout() && !rhs.has_layout()) {
484       return true;
485     }
486     if (!lhs.has_layout() || !rhs.has_layout()) {
487       return false;
488     }
489     return LayoutUtil::Equal(lhs.layout(), rhs.layout());
490   }
491   // Layouts of non-array and non-tuple shapes is ignored.
492   return true;
493 }
494 
AreDimensionsConsecutive(const Layout & layout,absl::Span<const int64_t> dims)495 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
496     const Layout& layout, absl::Span<const int64_t> dims) {
497   absl::InlinedVector<int64_t, 8> positions_in_layout;
498   for (int64_t dim : dims) {
499     positions_in_layout.push_back(
500         PositionInContainer(layout.minor_to_major(), dim));
501   }
502   absl::c_sort(positions_in_layout);
503   for (size_t i = 1; i < positions_in_layout.size(); ++i) {
504     if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
505       return false;
506     }
507   }
508   return true;
509 }
510 
/*static*/ Layout LayoutUtil::MoveDimToMajor(const Layout& layout,
                                             int64_t dim) {
  // Returns a copy of `layout` in which `dim` is the last (most-major) entry
  // of minor_to_major. The other entries keep their relative order, and all
  // other Layout fields are copied unchanged.
  // NOTE(review): back() is called on minor_to_major, so an empty layout is
  // not supported. If `dim` does not occur in the layout it is still appended.
  const absl::Span<const int64_t> original_order = MinorToMajor(layout);
  if (original_order.back() == dim) {
    return layout;  // Already the most-major dimension.
  }
  Layout moved = layout;
  moved.clear_minor_to_major();
  for (const int64_t d : original_order) {
    if (d == dim) {
      continue;
    }
    moved.add_minor_to_major(d);
  }
  moved.add_minor_to_major(dim);
  return moved;
}
524 
/*static*/ int64_t LayoutUtil::LinearIndex(const Shape& shape,
                                           absl::Span<const int64_t> indices) {
  // Computes the linear (physical) offset of the element at the logical
  // multi-index `indices` under the layout of `shape`. Only the first
  // (top-level) tile of the layout, if any, is taken into account.
  // CHECK-fails unless `shape` is an array with a layout and `indices` has
  // exactly rank() entries (indexed by logical dimension).
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());
  const int rank = shape.rank();
  CHECK_EQ(rank, indices.size());

  if (rank == 0) {
    return 0;
  }
  // NOTE(review): this early return bypasses tiling. It agrees with the
  // general loop below when the layout is untiled or has a single 1-D tile.
  // For a tile with more than one dimension, the loop would scale by the
  // full tile element count and give a different answer. Confirm the
  // shortcut is intended.
  if (rank == 1) {
    return indices[0];
  }

  // A default-constructed Tile (no dimensions) acts as "no tiling".
  Tile tile = {};
  if (!shape.layout().tiles().empty()) {
    tile = shape.layout().tiles()[0];
  }

  int64_t linear_index = 0;
  // Stride applied to tile-granular steps. It is accumulated as the loop
  // walks outward from the most-minor dimension.
  int64_t tile_multiplier = 1;
  // Initialize to number of elements in a tile.
  for (int64_t i : tile.dimensions()) {
    tile_multiplier *= i;
  }
  // Stride applied to steps inside a single tile.
  int64_t within_tile_multiplier = 1;

  // We only look at the top-level tile.
  // Walk the physical dimensions from most minor (minor == 0) to most major.
  for (int64_t minor = 0; minor < rank; minor++) {
    int64_t logical_dim = Minor(shape.layout(), minor);
    int64_t shape_dim_size = shape.dimensions(logical_dim);
    int64_t index = indices[logical_dim];

    if (minor < tile.dimensions().size()) {
      // The last tile entry pairs with the most-minor physical dimension,
      // the second-to-last with the next one, and so on.
      int64_t tile_dim_size =
          tile.dimensions()[tile.dimensions().size() - 1 - minor];
      // Split the index into "which tile" and "offset inside the tile".
      linear_index += tile_multiplier * (index / tile_dim_size) +
                      within_tile_multiplier * (index % tile_dim_size);
      // Number of tiles needed to cover this dimension (rounded up).
      tile_multiplier *= CeilOfRatio(shape_dim_size, tile_dim_size);
      within_tile_multiplier *= tile_dim_size;
    } else {
      // Dimensions beyond the tile's rank are untiled and simply step by the
      // accumulated stride.
      linear_index += index * tile_multiplier;
      tile_multiplier *= shape_dim_size;
    }
  }
  return linear_index;
}
572 
MemorySpace(const Shape & shape)573 /*static*/ int64_t LayoutUtil::MemorySpace(const Shape& shape) {
574   return shape.has_layout() ? shape.layout().memory_space()
575                             : Layout::kDefaultMemorySpace;
576 }
577 
578 }  // namespace xla
579