xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Exception.h>
2 #include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
3 
4 namespace torch::jit::mobile {
for_each_tensor_in_ivalue(const c10::IValue & iv,std::function<void (const::at::Tensor &)> const & func)5 void for_each_tensor_in_ivalue(
6     const c10::IValue& iv,
7     std::function<void(const ::at::Tensor&)> const& func) {
8   const bool is_leaf_type = iv.isString() || iv.isNone() || iv.isScalar() ||
9       iv.isDouble() || iv.isInt() || iv.isBool() || iv.isDevice() ||
10       iv.isIntList() || iv.isDoubleList() || iv.isBoolList();
11   if (is_leaf_type) {
12     // Do Nothing.
13     return;
14   }
15 
16   if (iv.isTensor()) {
17     func(iv.toTensor());
18   } else if (iv.isTuple()) {
19     c10::intrusive_ptr<at::ivalue::Tuple> tup_ptr = iv.toTuple();
20     for (const auto& e : tup_ptr->elements()) {
21       for_each_tensor_in_ivalue(e, func);
22     }
23   } else if (iv.isList()) {
24     c10::List<c10::IValue> l = iv.toList();
25     for (auto&& i : l) {
26       c10::IValue item = i;
27       for_each_tensor_in_ivalue(item, func);
28     }
29   } else if (iv.isGenericDict()) {
30     c10::Dict<c10::IValue, c10::IValue> dict = iv.toGenericDict();
31     for (auto& it : dict) {
32       for_each_tensor_in_ivalue(it.value(), func);
33     }
34   } else {
35     AT_ERROR("Unhandled type of IValue. Got ", iv.tagKind());
36   }
37 }
38 } // namespace torch::jit::mobile
39