#include <torch/csrc/jit/tensorexpr/operators/softmax.h>

namespace torch::jit::tensorexpr {

using namespace torch::jit::tensorexpr;

Tensor computeSoftmax(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    bool log_softmax) {
  // Softmax is computed as follows:
  //    softmax(vi) = exp(vi) / sum(exp(vi))
  //
  // In order to avoid overflow issues due to exp of a large number, we
  // subtract the max of that dim before computing exp.
  //    softmax(vi) = exp(vi - max(vi)) / sum(exp(vi - max(vi)))
  //
  // This is implemented as 4 loopnests:
  //   - First loop computes the max over the softmax dim.
  //   - Second loop computes exp for every element in v after subtracting
  //     the max of the softmax dim it belongs to.
  //   - Third loop computes the sum over the softmax dim.
  //   - Final loop computes softmax for every element in v.

  // LogSoftmax is computed as follows:
  //    log_softmax(vi) = log(softmax(vi))
  //                    = vi - log(sum(exp(vi)))
  //
  // Using the same max trick as above:
  //    log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
  //
  // This is implemented as 5 loopnests:
  //   - First loop computes the max over the softmax dim.
  //   - Second loop computes exp for every element in v after subtracting
  //     the max of the softmax dim it belongs to.
  //   - Third loop computes the sum over the softmax dim.
  //   - Fourth loop computes log for every element in the sum.
  //   - Final loop computes the log_softmax for every element in v.

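  // inputs is expected to carry three arguments, mirroring
  // aten::softmax / aten::log_softmax: the input tensor, the softmax dim,
  // and an optional output dtype.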
  TORCH_INTERNAL_ASSERT(inputs.size() == 3);

  // We do not handle None for dims (input 1) because that is supposed to
  // be deprecated.
  TORCH_INTERNAL_ASSERT(std::get_if<int64_t>(&inputs[1]));
  int64_t rank = valueShape(inputs[0]).size();
  size_t softmax_dim =
      normalizeAndCheckIndex(std::get<int64_t>(inputs[1]), rank);
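  // Collect the output dims excluding the softmax dim; these become the
  // dims of the reduction buffers (max, sum, and log_sum) below.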
  std::vector<ExprHandle> non_softmax_dims;
  for (size_t i = 0; i < outputShape.size(); ++i) {
    if (i != softmax_dim) {
      non_softmax_dims.push_back(outputShape[i]);
    }
  }

  // Softmax implementation includes two reductions, one to find the max and
  // the other to calculate the sum along the softmax dim. These reductions
  // will have the softmax dimension as the innermost loop. So, the innermost
  // index in the indices will refer to the softmax dimension.

  // Update the indices by moving the softmax dimension index to the
  // appropriate position.
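  // For example, with rank 3 and softmax_dim == 1, the reduction indices
  // [i0, i2, r] (where r iterates over the softmax dim) become [i0, r, i2].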
  auto move_softmax_dim_index_to_pos = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices;
    for (const auto& ind : indices) {
      new_indices.push_back(ind);
    }
    for (size_t i = softmax_dim; i < indices.size() - 1; ++i) {
      new_indices[i + 1] = indices[i];
    }
    new_indices[softmax_dim] = indices[indices.size() - 1];
    return new_indices;
  };

  // Remove the index corresponding to the softmax dimension.
  auto remove_softmax_dim_index = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (i != softmax_dim) {
        new_indices.push_back(indices[i]);
      }
    }
    return new_indices;
  };

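  // Convert the parameter indices to ExprHandles as-is, with no reordering
  // or removal.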
  auto convert_indices_to_expr_handle = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices(indices.size());
    for (size_t i = 0; i < indices.size(); ++i) {
      new_indices[i] = indices[i];
    }
    return new_indices;
  };

  auto inp_buf = std::get<BufHandle>(inputs[0]);

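  // Use the requested output dtype (inputs[2]) if one was given; otherwise
  // fall back to the input's dtype. This dtype parametrizes the max
  // reduction below.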
  auto dtype = inp_buf.dtype();
  if (auto d = std::get_if<int64_t>(&inputs[2])) {
    dtype = ToDtype(static_cast<ScalarType>(*d));
  }

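  // First loopnest: reduce over the softmax dim to find the max of each
  // slice, loading the input with the reduction index moved into the
  // softmax position.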
  auto max = Reduce(
      "aten_softmax_max",
      non_softmax_dims,
      std::nullopt,
      Maximum(dtype),
      [&](ParameterList& indices) {
        return tensorOrConstant(
            inputs[0], move_softmax_dim_index_to_pos(indices));
      },
      {outputShape[softmax_dim]});
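  // Second loopnest: compute exp(v - max) for every element of the output.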
  auto e = Compute(
      "aten_softmax_exp",
      outputShape,
      std::nullopt,
      [&](ParameterList& indices) {
        auto inp = tensorOrConstant(
            inputs[0], convert_indices_to_expr_handle(indices));
        return exp(inp - max.load(remove_softmax_dim_index(indices)));
      });
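  // Third loopnest: sum the exponentials over the softmax dim.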
  auto sum = Reduce(
      "aten_softmax_sum",
      non_softmax_dims,
      std::nullopt,
      Sum(),
      [&](ParameterList& indices) {
        return e.load(move_softmax_dim_index_to_pos(indices));
      },
      {outputShape[softmax_dim]});
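  // Softmax: divide each exponential by the sum over its softmax dim and
  // stitch the four loopnests together into a single block.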
  if (!log_softmax) {
    auto result = Compute(
        "aten_softmax", outputShape, std::nullopt, [&](ParameterList& indices) {
          return e.load(indices) / sum.load(remove_softmax_dim_index(indices));
        });
    return Tensor(
        result.buf(),
        alloc<tensorexpr::Block>(std::vector<StmtPtr>(
            {max.stmt(), e.stmt(), sum.stmt(), result.stmt()})));
  }

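  // LogSoftmax: the fourth loopnest takes the log of each per-slice sum.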
  auto log_sum = Compute(
      "aten_softmax_log_sum",
      non_softmax_dims,
      std::nullopt,
      [&](ParameterList& indices) { return log(sum.load(indices)); });
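  // Final loopnest: log_softmax(v) = v - max - log(sum), then stitch all
  // five loopnests together into a single block.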
  auto result = Compute(
      "aten_log_softmax",
      outputShape,
      std::nullopt,
      [&](ParameterList& indices) {
        auto inp = tensorOrConstant(
            inputs[0], convert_indices_to_expr_handle(indices));
        auto non_softmax_indices = remove_softmax_dim_index(indices);
        return inp - max.load(non_softmax_indices) -
            log_sum.load(non_softmax_indices);
      });
  return Tensor(
      result.buf(),
      alloc<tensorexpr::Block>(std::vector<StmtPtr>(
          {max.stmt(), e.stmt(), sum.stmt(), log_sum.stmt(), result.stmt()})));
}

} // namespace torch::jit::tensorexpr