修改ONNX模型节点

发布于:2023-09-15 ⋅ 阅读:(108) ⋅ 点赞:(0)

网上几乎所有相关的帖子都忽略了一个问题：修改后的节点列表必须满足拓扑排序，否则后面的 check_model 会报错。（其实就是要保证节点的先后顺序：每个节点的输入都必须由排在它之前的节点产生，或者来自图的输入 / initializer。）

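以修改ViT中的Concat节点为例：把名为 /Concat 的节点替换为“每个输入先做 Transpose → Concat(axis=0) → 再做 Transpose”的一组节点，并把它们插在原节点所在的位置，使后续节点的先后顺序保持不变：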

import onnx

# Load the model and locate the node to rewrite.
onnx_model = onnx.load("./vit.onnx")
graph = onnx_model.graph
# `node` is the graph's repeated node field itself, so edits made through
# `node` or `graph.node` modify onnx_model in place.
node = graph.node

orig_len = len(node)
node_name = '/Concat'
# Index of the node named `node_name`. If several nodes share the name, the
# last one wins (same as a plain forward scan that keeps overwriting).
node_index = -1
for i, candidate in enumerate(node):
    if candidate.name == node_name:
        node_index = i
if node_index < 0:
    # Without this guard node_index would stay -1 and the rewrite below would
    # silently operate on node[-1] (the last node of the graph).
    raise ValueError(f"node {node_name!r} not found in graph")

# Tensor names consumed / produced by the original Concat node (copied out
# before the node is deleted below).
inputs_list = list(node[node_index].input)
outputs_list = list(node[node_index].output)

# Replacement subgraph, already in topological order:
#   Transpose(in_0), ..., Transpose(in_{N-1}) -> Concat(axis=0) -> Transpose
# NOTE(review): Transpose(perm=[1, 0, 2]) + Concat(axis=0) + Transpose(perm=[1, 0, 2])
# equals Concat(axis=1) only for rank-3 inputs; this assumes the original node
# concatenated rank-3 tensors along axis 1 - confirm before reusing elsewhere.
new_nodes = [
    onnx.helper.make_node(
        "Transpose",
        inputs=[input_name],
        outputs=[input_name + '_add'],
        name='Transpose_' + str(node_index + i),
        perm=[1, 0, 2],
    )
    for i, input_name in enumerate(inputs_list)
]
new_nodes.append(
    onnx.helper.make_node(
        'Concat',
        inputs=[item + '_add' for item in inputs_list],
        outputs=['output_changed'],
        name=node_name,  # keep the original node's name
        axis=0,
    )
)
new_nodes.append(
    onnx.helper.make_node(
        "Transpose",
        inputs=['output_changed'],
        # Reuse the original output names so downstream consumers are untouched.
        outputs=list(outputs_list),
        name='Transpose_' + str(node_index + len(inputs_list) + 1),
        perm=[1, 0, 2],
    )
)

# Splice the replacements in at the old node's position. Each insert shifts
# the later nodes to the right, so every downstream node still comes after
# its producers (topological order is kept) no matter where the target node
# sits in the graph - no manual shifting of the tail is needed.
del graph.node[node_index]
for offset, new_node in enumerate(new_nodes):
    graph.node.insert(node_index + offset, new_node)

# All edits above were applied in place to onnx_model.graph, so onnx_model is
# the modified model: validate exactly that object, then save it. (Shape
# information can be refreshed first with
# onnx.shape_inference.infer_shapes(onnx_model) if needed.)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, './changed_vit.onnx')