1. Exporting with dynamic shapes
1) Function walkthrough (export call and output check)
① torch.onnx.export
```python
torch.onnx.export(
    clip_model,
    (tokens,),
    onnx_path,
    verbose=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)
```
(1) export_params: defaults to True, in which case the exported ONNX file contains all of the model's parameters (weights, biases, etc.). When set to False, the file contains only the computation graph structure, without the parameters, so it is much smaller because no parameter data is stored (a short sketch contrasting the two settings follows this list).
(2) verbose: when True, the exporter prints a detailed log during export.
(3) do_constant_folding: a boolean, usually True, controlling whether constant folding is applied during export to improve inference performance. When enabled, every operation whose inputs are all constants is precomputed and replaced with its result, simplifying the graph and reducing the model's computation.
(4) input_names and output_names: lists of names assigned to the graph's input and output tensors; dynamic_axes and the runtime refer to tensors by these names.
(5) dynamic_axes: a dict whose keys are input or output tensor names and whose values are dicts specifying which dimensions of that tensor are dynamic. In the inner dict, the key is the dimension index (0-based) and the value is a string label for that dynamic dimension; at runtime the size along a labeled axis can vary from call to call (see the runtime sketch after the example values below).
(6) opset_version: the ONNX opset version to target.
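As a quick illustration of export_params, here is a minimal sketch with a hypothetical toy model (the file names toy_full.onnx / toy_graph_only.onnx are placeholders, not part of this project):

```python
import os
import torch
import torch.nn as nn

toy = nn.Linear(1024, 1024)  # hypothetical toy model, ~4 MB of float32 weights
x = torch.randn(1, 1024)

torch.onnx.export(toy, (x,), "toy_full.onnx", export_params=True)
torch.onnx.export(toy, (x,), "toy_graph_only.onnx", export_params=False)

# The graph-only file should be far smaller, since no weight data is stored.
print(os.path.getsize("toy_full.onnx"), os.path.getsize("toy_graph_only.onnx"))
```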
Example values for dynamic_axes:
```python
dynamic_axes = {
    "x": {0: "batch_size"},
    "hint": {0: "batch_size"},
    "timesteps": {0: "batch_size"},
    "context": {0: "batch_size", 1: "sequence_length"},
    "output": {0: "batch_size", 1: "hint_height", 2: "hint_width"},
}

dynamic_axes = {
    "input_ids": {1: "S"},
    "last_hidden_state": {1: "S"},
}

dynamic_axes = {
    "x": {0: "latent"},
}
```
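To make the effect of dynamic_axes concrete, here is a minimal, self-contained sketch (a toy nn.Embedding stands in for a real text encoder; the names toy_dyn.onnx, input_ids, and S mirror the second example above) that exports with a dynamic sequence axis and then runs the session at several different lengths:

```python
import numpy as np
import torch
import torch.nn as nn
import onnxruntime as rt

toy = nn.Embedding(1000, 16)  # toy stand-in for the CLIP text encoder
tokens = torch.zeros(1, 77, dtype=torch.int64)

torch.onnx.export(
    toy,
    (tokens,),
    "toy_dyn.onnx",
    opset_version=18,
    input_names=["input_ids"],
    output_names=["last_hidden_state"],
    dynamic_axes={"input_ids": {1: "S"}, "last_hidden_state": {1: "S"}},
)

sess = rt.InferenceSession("toy_dyn.onnx")
for S in (8, 77, 128):  # axis 1 may change between calls
    out = sess.run(None, {"input_ids": np.zeros((1, S), dtype=np.int64)})[0]
    print(out.shape)  # (1, S, 16)
```

Without the dynamic_axes entry, the exported graph would pin axis 1 to the example input's size (77), and the S=8 and S=128 calls would be rejected at run time.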
② Error check
```python
import numpy as np
import onnx
import onnxruntime as rt

# onnx_path:     path to the exported ONNX file
# input_dicts:   dict mapping input names to numpy arrays
# torch_outputs: list of outputs produced by the PyTorch model
def onnxruntime_check(onnx_path, input_dicts, torch_outputs):
    onnx_model = onnx.load(onnx_path)
    # onnx.checker.check_model(onnx_model)
    sess = rt.InferenceSession(onnx_path)
    result = sess.run(None, input_dicts)
    for i in range(len(torch_outputs)):
        ret = np.allclose(result[i], torch_outputs[i].detach().numpy(),
                          rtol=1e-03, atol=1e-05, equal_nan=False)
        if not ret:
            print(f"Error onnxruntime_check: output {i} does not match")
```
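For reference, np.allclose as used here passes only when every element satisfies |result - expected| <= atol + rtol * |expected|. A hand-written equivalent of that test:

```python
import numpy as np

result = np.array([1.0000, 2.0010])
expected = np.array([1.0001, 2.0000])

# Equivalent to np.allclose(result, expected, rtol=1e-03, atol=1e-05)
ok = bool(np.all(np.abs(result - expected) <= 1e-05 + 1e-03 * np.abs(expected)))
print(ok)  # True: both elements fall inside the combined tolerance
```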
2) Full code
- Code:
```python
import datetime
import os
import random

import cv2
import einops
import gradio as gr
import numpy as np
import torch
from pytorch_fid import fid_score
from pytorch_fid.inception import InceptionV3
from pytorch_lightning import seed_everything

from share import *
import config
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

import onnx
import onnx_graphsurgeon as gs
import onnxruntime as rt
from onnx import shape_inference


def optimize(onnx_path, opt_onnx_path):
    from onnxsim import simplify
    model = onnx.load(onnx_path)
    graph = gs.import_onnx(model)
    print(f"{onnx_path} simplify start!")
    model_simp, check = simplify(model)
    # Validate the simplified model before writing it to disk.
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(model_simp, opt_onnx_path, save_as_external_data=True)
    print(f"{onnx_path} simplify done!")


def onnxruntime_check(onnx_path, input_dicts, torch_outputs):
    onnx_model = onnx.load(onnx_path)
    # onnx.checker.check_model(onnx_model)
    sess = rt.InferenceSession(onnx_path)
    result = sess.run(None, input_dicts)
    for i in range(len(torch_outputs)):
        ret = np.allclose(result[i], torch_outputs[i].detach().numpy(),
                          rtol=1e-03, atol=1e-05, equal_nan=False)
        if not ret:
            print(f"Error onnxruntime_check: output {i} does not match")


class hackathon():
    def initialize(self):
        self.apply_canny = CannyDetector()
        self.model = create_model('./models/cldm_v15.yaml').cpu()
        self.model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))
        self.model = self.model.cpu()
        self.model.eval()
        self.ddim_sampler = DDIMSampler(self.model)


hk = hackathon()
hk.initialize()


def export_clip_model():
    clip_model = hk.model.cond_stage_model

    import types

    # Patch forward so the exported graph takes token ids as its only input
    # and returns the hidden states selected by the wrapper's configuration.
    def forward(self, tokens):
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    clip_model.forward = types.MethodType(forward, clip_model)

    onnx_path = "./onnx/CLIP.onnx"
    tokens = torch.zeros(1, 77, dtype=torch.int32)
    input_names = ["input_ids"]
    output_names = ["last_hidden_state"]
    dynamic_axes = {
        "input_ids": {1: "S"},
        "last_hidden_state": {1: "S"},
    }
    torch.onnx.export(
        clip_model,
        (tokens,),
        onnx_path,
        verbose=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
    print("======================= CLIP model export onnx done!")

    # Verify the ONNX model against the patched PyTorch model.
    output = clip_model(tokens)
    input_dicts = {"input_ids": tokens.numpy()}
    onnxruntime_check(onnx_path, input_dicts, [output])
    print("======================= CLIP onnx model verify done!")

    # opt_onnx_path = "./onnx/CLIP.opt.onnx"
    # optimize(onnx_path, opt_onnx_path)


def export_control_net_model():
    control_net_model = hk.model.control_model
    onnx_path = "./onnx/control_net_model.onnx"

    def get_shape(B=1, S=64):
        return [(B, 4, 32, 48), (B, 3, 256, 384), tuple([B])
```