xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/operators/softmax.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/tensorexpr/operators/softmax.h>

namespace torch::jit::tensorexpr {

using namespace torch::jit::tensorexpr;

Tensor computeSoftmax(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    bool log_softmax) {
  // Softmax is computed as follows:
  //    softmax(vi) = exp(vi) / sum(exp(vi))
  //
  // In order to avoid overflow issues due to exp of a large number, we
  // subtract the max of that dim before computing exp.
  //    softmax(vi) = exp(vi - max(vi)) / sum(exp(vi - max(vi)))
  //
  // This is implemented as 4 loopnests:
  //   - First loop computes the max over the softmax dim.
  //   - Second loop computes exp for every element in v after subtracting
  //     the max of the softmax dim it belongs to.
  //   - Third loop computes the sum over the softmax dim.
  //   - Final loop computes softmax for every element in v.
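  //
  // A worked example (illustrative numbers, not from the original comment):
  // for v = [1, 2, 3], max(v) = 3 and exp(v - 3) ~= [0.1353, 0.3679, 1.0],
  // which sums to ~1.5032, giving softmax(v) ~= [0.0900, 0.2447, 0.6652].
  // Subtracting the max does not change the result, since the common factor
  // exp(-max(v)) cancels between numerator and denominator.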

  // LogSoftmax is computed as follows:
  //    log_softmax(vi) = log(softmax(vi))
  //                    = vi - log(sum(exp(vi)))
  //
  // Using the same max trick as above:
  //    log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
  //
  // This is implemented as 5 loopnests:
  //   - First loop computes the max over the softmax dim.
  //   - Second loop computes exp for every element in v after subtracting
  //     the max of the softmax dim it belongs to.
  //   - Third loop computes the sum over the softmax dim.
  //   - Fourth loop computes log for every element in the sum.
  //   - Final loop computes the log_softmax for every element in v.
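  //
  // Continuing the worked example above: for v = [1, 2, 3],
  // log(sum(exp(v - 3))) = log(1.5032) ~= 0.4076, so
  // log_softmax(v) ~= [-2.4076, -1.4076, -0.4076], i.e. the elementwise
  // log of the softmax values computed earlier.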

  TORCH_INTERNAL_ASSERT(inputs.size() == 3);

  // We do not handle None for dims (input 1) because that is supposed to
  // be deprecated.
  TORCH_INTERNAL_ASSERT(std::get_if<int64_t>(&inputs[1]));
  int64_t rank = valueShape(inputs[0]).size();
  size_t softmax_dim =
      normalizeAndCheckIndex(std::get<int64_t>(inputs[1]), rank);
  std::vector<ExprHandle> non_softmax_dims;
  for (size_t i = 0; i < outputShape.size(); ++i) {
    if (i != softmax_dim) {
      non_softmax_dims.push_back(outputShape[i]);
    }
  }

  // Softmax implementation includes two reductions, one to find the max and
  // the other to calculate the sum along the softmax dim. These reductions
  // will have the softmax dimension as the innermost loop, so the innermost
  // index in the indices will refer to the softmax dimension.

  // Update the indices by moving the softmax dimension index to the
  // appropriate position.
  auto move_softmax_dim_index_to_pos = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices;
    for (const auto& ind : indices) {
      new_indices.push_back(ind);
    }
    for (size_t i = softmax_dim; i < indices.size() - 1; ++i) {
      new_indices[i + 1] = indices[i];
    }
    new_indices[softmax_dim] = indices[indices.size() - 1];
    return new_indices;
  };
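  // For example (an illustrative sketch, not from the original source):
  // with softmax_dim = 1, move_softmax_dim_index_to_pos maps indices
  // (i, j, k), where k is the reduction index over the softmax dim, to
  // (i, k, j): the last index is rotated into position softmax_dim and
  // the indices after it shift right by one.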

  // Remove the index corresponding to the softmax dimension.
  auto remove_softmax_dim_index = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (i != softmax_dim) {
        new_indices.push_back(indices[i]);
      }
    }
    return new_indices;
  };
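  // E.g. (illustration only): with softmax_dim = 1, indices (i, j, k)
  // become (i, k), matching the shape of the buffers defined over
  // non_softmax_dims.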

  auto convert_indices_to_expr_handle = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices(indices.size());
    for (size_t i = 0; i < indices.size(); ++i) {
      new_indices[i] = indices[i];
    }
    return new_indices;
  };

  auto inp_buf = std::get<BufHandle>(inputs[0]);

  // inputs[2] holds the optional dtype argument; when it is an integer
  // (rather than None), compute in that scalar type instead of the input
  // buffer's dtype.
  auto dtype = inp_buf.dtype();
  if (auto d = std::get_if<int64_t>(&inputs[2])) {
    dtype = ToDtype(static_cast<ScalarType>(*d));
  }

  auto max = Reduce(
      "aten_softmax_max",
      non_softmax_dims,
      std::nullopt,
      Maximum(dtype),
      [&](ParameterList& indices) {
        return tensorOrConstant(
            inputs[0], move_softmax_dim_index_to_pos(indices));
      },
      {outputShape[softmax_dim]});
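  // Conceptually (a sketch of the loopnest this Reduce describes, assuming
  // a 2D input with softmax_dim = 1):
  //   for (i0 in non_softmax_dims)
  //     for (r in outputShape[softmax_dim])
  //       max[i0] = Max(max[i0], input[i0, r])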
  auto e = Compute(
      "aten_softmax_exp",
      outputShape,
      std::nullopt,
      [&](ParameterList& indices) {
        auto inp = tensorOrConstant(
            inputs[0], convert_indices_to_expr_handle(indices));
        return exp(inp - max.load(remove_softmax_dim_index(indices)));
      });
  auto sum = Reduce(
      "aten_softmax_sum",
      non_softmax_dims,
      std::nullopt,
      Sum(),
      [&](ParameterList& indices) {
        return e.load(move_softmax_dim_index_to_pos(indices));
      },
      {outputShape[softmax_dim]});
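  // Continuing the 2D sketch above: e[i0, i1] = exp(input[i0, i1] - max[i0]),
  // and sum[i0] reduces e over the softmax dim, i.e. the sum over r of
  // e[i0, r].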
  if (!log_softmax) {
    auto result = Compute(
        "aten_softmax", outputShape, std::nullopt, [&](ParameterList& indices) {
          return e.load(indices) / sum.load(remove_softmax_dim_index(indices));
        });
    return Tensor(
        result.buf(),
        alloc<tensorexpr::Block>(std::vector<StmtPtr>(
            {max.stmt(), e.stmt(), sum.stmt(), result.stmt()})));
  }

  auto log_sum = Compute(
      "aten_softmax_log_sum",
      non_softmax_dims,
      std::nullopt,
      [&](ParameterList& indices) { return log(sum.load(indices)); });
  auto result = Compute(
      "aten_log_softmax",
      outputShape,
      std::nullopt,
      [&](ParameterList& indices) {
        auto inp = tensorOrConstant(
            inputs[0], convert_indices_to_expr_handle(indices));
        auto non_softmax_indices = remove_softmax_dim_index(indices);
        return inp - max.load(non_softmax_indices) -
            log_sum.load(non_softmax_indices);
      });
  return Tensor(
      result.buf(),
      alloc<tensorexpr::Block>(std::vector<StmtPtr>(
          {max.stmt(), e.stmt(), sum.stmt(), log_sum.stmt(), result.stmt()})));
}

} // namespace torch::jit::tensorexpr