1 #pragma once
#include <optional>
#include <stdexcept>
#include <utility>

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/scope.h>
7
8 // helpers for handling constants in the IR
9 // - create constant nodes from ints, floats, complex, intlist, Tensors, and
10 // other types
11 // - implement primitive constant ops.
12
13 namespace torch::jit {
14
15 using ::c10::IValue;
16
17 struct Graph;
18 struct Value;
19
20 // thrown when insertConstant cannot encode the IValue into a graph
21 struct TORCH_API constant_not_supported_error : public std::runtime_error {
22 using runtime_error::runtime_error;
23 };
24
25 TORCH_API Value* insertConstant(
26 Graph& g,
27 const IValue& val,
28 std::optional<SourceRange> loc = std::nullopt,
29 std::optional<ScopePtr> scope = std::nullopt);
30
31 // note: prefer g.insertConsant(val, loc) which does exactly the same thing
32 // this function is only declared/defined here because its implementation is
33 // closely related to the implementation of prim::Constant that is also in
34 // constants.cpp.
35 //
36 // returns a std::nullopt if the IValue kind cannot be inserted as a constant
37 TORCH_API std::optional<Value*> tryInsertConstant(
38 Graph& g,
39 const IValue& val,
40 std::optional<SourceRange> loc = std::nullopt,
41 std::optional<ScopePtr> scope = std::nullopt);
42
43 ////////////////////////////////////////////////////////////////////////////////
44 // Helper for retrieving constants
45 ////////////////////////////////////////////////////////////////////////////////
46
47 // attempt to convert a (possibly constant) Value* into an interpreter value
48 // (IValue). returns std::nullopt if the Value* was not constant
49 TORCH_API std::optional<IValue> toIValue(const Value* v);
50
51 // if a value is a constant then try to turn into type T using the
52 // same rules as the interpreter
53 template <typename T>
constant_as(const Value * v)54 std::optional<T> constant_as(const Value* v) {
55 if (auto ivalue = toIValue(v)) {
56 return ivalue->to<T>();
57 }
58 return std::nullopt;
59 }
60 } // namespace torch::jit
61