xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/constants.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/scope.h>

#include <optional>
#include <stdexcept>
#include <utility>
7 
8 // helpers for handling constants in the IR
9 // - create constant nodes from ints, floats, complex, intlist, Tensors, and
10 // other types
11 // - implement primitive constant ops.
12 
13 namespace torch::jit {
14 
15 using ::c10::IValue;
16 
17 struct Graph;
18 struct Value;
19 
20 // thrown when insertConstant cannot encode the IValue into a graph
21 struct TORCH_API constant_not_supported_error : public std::runtime_error {
22   using runtime_error::runtime_error;
23 };
24 
25 TORCH_API Value* insertConstant(
26     Graph& g,
27     const IValue& val,
28     std::optional<SourceRange> loc = std::nullopt,
29     std::optional<ScopePtr> scope = std::nullopt);
30 
31 // note: prefer g.insertConsant(val, loc) which does exactly the same thing
32 // this function is only declared/defined here because its implementation is
33 // closely related to the implementation of prim::Constant that is also in
34 // constants.cpp.
35 //
36 // returns a std::nullopt if the IValue kind cannot be inserted as a constant
37 TORCH_API std::optional<Value*> tryInsertConstant(
38     Graph& g,
39     const IValue& val,
40     std::optional<SourceRange> loc = std::nullopt,
41     std::optional<ScopePtr> scope = std::nullopt);
42 
43 ////////////////////////////////////////////////////////////////////////////////
44 // Helper for retrieving constants
45 ////////////////////////////////////////////////////////////////////////////////
46 
47 // attempt to convert a (possibly constant) Value* into an interpreter value
48 // (IValue). returns std::nullopt if the Value* was not constant
49 TORCH_API std::optional<IValue> toIValue(const Value* v);
50 
51 // if a value is a constant then try to turn into type T using the
52 // same rules as the interpreter
53 template <typename T>
constant_as(const Value * v)54 std::optional<T> constant_as(const Value* v) {
55   if (auto ivalue = toIValue(v)) {
56     return ivalue->to<T>();
57   }
58   return std::nullopt;
59 }
60 } // namespace torch::jit
61