1import json 2from typing import List 3 4from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node 5from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder 6from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode 7 8 9def serialize_extern_kernel_node( 10 extern_kernel_node: inductor_ExternKernelNode, 11) -> ExternKernelNode: 12 assert isinstance(extern_kernel_node.node, Node) 13 return ExternKernelNode( 14 name=extern_kernel_node.name, 15 node=extern_kernel_node.node, 16 ) 17 18 19def extern_node_json_serializer( 20 extern_kernel_nodes: List[inductor_ExternKernelNode], 21) -> str: 22 serialized_nodes = ExternKernelNodes( 23 nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] 24 ) 25 return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) 26