// xref: /aosp_15_r20/external/pytorch/c10/core/Layout.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Backend.h>
4 #include <c10/util/Exception.h>
5 
6 #include <cstdint>
7 #include <ostream>
8 
9 namespace c10 {
/// Identifies a tensor memory layout.
///
/// The underlying type is int8_t. Enumerator values are assigned implicitly
/// by declaration order (Strided == 0 ... Jagged == 7), so inserting or
/// reordering enumerators changes their numeric values.
enum class Layout : int8_t {
  Strided,
  Sparse,
  SparseCsr,
  Mkldnn,
  SparseCsc,
  SparseBsr,
  SparseBsc,
  Jagged,
  // Sentinel: declared last, so its value equals the number of layouts
  // listed above. It is not a layout itself.
  NumOptions
};
21 
22 constexpr auto kStrided = Layout::Strided;
23 constexpr auto kSparse = Layout::Sparse;
24 constexpr auto kSparseCsr = Layout::SparseCsr;
25 constexpr auto kMkldnn = Layout::Mkldnn;
26 constexpr auto kSparseCsc = Layout::SparseCsc;
27 constexpr auto kSparseBsr = Layout::SparseBsr;
28 constexpr auto kSparseBsc = Layout::SparseBsc;
29 constexpr auto kJagged = Layout::Jagged;
30 
layout_from_backend(Backend backend)31 inline Layout layout_from_backend(Backend backend) {
32   switch (backend) {
33     case Backend::SparseCPU:
34     case Backend::SparseCUDA:
35     case Backend::SparseHIP:
36     case Backend::SparseVE:
37     case Backend::SparseXPU:
38     case Backend::SparsePrivateUse1:
39       return Layout::Sparse;
40     case Backend::MkldnnCPU:
41       return Layout::Mkldnn;
42     case Backend::SparseCsrCPU:
43     case Backend::SparseCsrCUDA:
44     case Backend::SparseCsrHIP:
45     case Backend::SparseCsrVE:
46     case Backend::SparseCsrXPU:
47       TORCH_CHECK(
48           false,
49           "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout.");
50     default:
51       return Layout::Strided;
52   }
53 }
54 
55 inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
56   switch (layout) {
57     case at::kStrided:
58       return stream << "Strided";
59     case at::kSparse:
60       return stream << "Sparse";
61     case at::kSparseCsr:
62       return stream << "SparseCsr";
63     case at::kSparseCsc:
64       return stream << "SparseCsc";
65     case at::kSparseBsr:
66       return stream << "SparseBsr";
67     case at::kSparseBsc:
68       return stream << "SparseBsc";
69     case at::kMkldnn:
70       return stream << "Mkldnn";
71     case at::kJagged:
72       return stream << "Jagged";
73     default:
74       TORCH_CHECK(false, "Unknown layout");
75   }
76 }
77 
78 } // namespace c10
79