xref: /aosp_15_r20/external/executorch/extension/pytree/aten_util/ivalue_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <ATen/ATen.h>
12 #include <ATen/core/TensorBody.h>
13 #include <ATen/core/ivalue.h>
14 #include <memory>
15 #include <utility>
16 
17 // patternlint-disable executorch-cpp-nostdinc
18 #include <vector>
19 
20 #include <executorch/extension/pytree/pytree.h>
21 
22 namespace executorch {
23 namespace extension {
24 
25 std::pair<
26     std::vector<at::Tensor>,
27     std::unique_ptr<::executorch::extension::pytree::TreeSpec<
28         ::executorch::extension::pytree::Empty>>>
29 flatten(const c10::IValue& data);
30 
31 c10::IValue unflatten(
32     const std::vector<at::Tensor>& tensors,
33     const std::unique_ptr<::executorch::extension::pytree::TreeSpec<
34         ::executorch::extension::pytree::Empty>>& tree_spec);
35 
36 bool is_same(
37     const std::vector<at::Tensor>& a,
38     const std::vector<at::Tensor>& b);
39 
40 bool is_same(const c10::IValue& lhs, const c10::IValue& rhs);
41 
42 } // namespace extension
43 } // namespace executorch
44 
45 namespace torch {
46 namespace executor {
47 namespace util {
48 // TODO(T197294990): Remove these deprecated aliases once all users have moved
49 // to the new `::executorch` namespaces.
50 using ::executorch::extension::flatten;
51 using ::executorch::extension::is_same;
52 using ::executorch::extension::unflatten;
53 } // namespace util
54 } // namespace executor
55 } // namespace torch
56