目录
image_classifier示例(densenet161)
TorchServe架构
TorchScript 格式部署
在使用 TorchServe 部署 PyTorch 模型时,您可以选择直接使用 Python 模型(即“普通注册”)或将模型转换为 TorchScript 格式。这两种方法的主要区别包括:
模型格式与序列化:
普通注册:直接使用 Python 定义的模型,通常以
.pth
文件形式保存模型权重。TorchScript 方法:利用 PyTorch 的 TorchScript 功能,将模型转换为中间表示(IR),生成一个可序列化的
.pt
文件。Amazon Web Services, Inc.+7PyTorch Developer Mailing List+7Medium+7
部署环境:
普通注册:需要在部署环境中安装完整的 Python 解释器和相关依赖库。
TorchScript 方法:转换后的模型可以在不依赖 Python 的环境中运行,适用于 C++ 等其他语言的运行时环境。 Stack Overflow+1Hugging Face+1
torch.compile
函数
用于通过即时(JIT)编译优化模型或函数的运行效率。该函数接受多个参数,允许用户根据需求调整编译行为。以下是主要参数及其作用:CSDN+2CSDN+2听歌语+2掘金
model
:(可选)需要优化的模型或函数,通常是一个torch.nn.Module
实例或可调用的 Python 函数。fullgraph
:(布尔值)指示是否要求编译整个计算图。如果设置为True
,则强制编译完整的计算图;如果为False
,则允许在遇到无法编译的代码时将计算图拆分为子图。默认值为False
。 掘金dynamic
:(布尔值)指示是否支持动态形状。如果设置为True
,编译的模型可以处理输入张量的动态变化,但可能会影响性能。默认值为False
。 掘金backend
:(字符串或可调用对象)指定用于编译的后端。默认值为"inductor"
,这是 PyTorch 内置的编译器,能够将计算图转换为优化的机器代码。用户可以根据需求选择其他后端,如"eager"
(不进行额外优化)等。 CSDN+2掘金+2CSDN+2mode
:(字符串)控制编译模式,影响编译时间和运行时性能的权衡。可选值包括:"default"
:适用于大多数情况,编译速度和运行时性能之间取得平衡。"reduce-overhead"
:适用于小型模型,可能增加编译时间,但减少运行时开销。"max-autotune"
:编译时间最长,但可能提供最佳的运行时性能。 CSDN+1Hugging Face+1
options
:(字典)向后端传递的额外选项,允许用户进一步自定义编译行为。disable
:(布尔值)如果设置为True
,则禁用torch.compile
的功能,使其成为一个空操作,主要用于测试目的。默认值为False
。
1-环境配置
(base) liguangzhen@ubuntu:~$ java -version
java version "17.0.12" 2024-07-16 LTS
Java(TM) SE Runtime Environment (build 17.0.12+8-LTS-286)
Java HotSpot(TM) 64-Bit Server VM (build 17.0.12+8-LTS-286, mixed mode, sharing)
参考: 入门指南
conda create -n torchserve python=3.8 -y
conda activate torchserve
python ./ts_scripts/install_dependencies.py --cuda=cu121
conda install torchserve torch-model-archiver torch-workflow-archiver -c pytorch
pip install nvgpu
git clone https://github.com/pytorch/serve.git
2-示例
torchServe==0.12
image_classifier示例(densenet161)
1. 创建目录存储模型
2. 进入文件夹创建一个目录来存储您的模型,下载权重。
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
3. 创建执行文件或者使用指令
# densenet161.sh
torch-model-archiver --model-name densenet161 --version 1.0 \
--model-file /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/densenet_161/model.py \
--serialized-file /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/densenet_161/densenet161-8d451a50.pth \
--export-path /home74/liguangzhen/Project/Serve/serve/model_store \
--extra-files /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/index_to_name.json \
--handler image_classifier \
-f
# torchserve --start --ncs --model-store model_store --models densenet161.mar
# curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg
4. 启动serve
# 启动
torchserve --start --ncs --model-store model_store --models densenet161.mar
# 测试
curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg
# 停止
torchserve --stop
执行路径:(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ torchserve --start --ncs --model-store model_store --models densenet161.mar
问题:
(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg
{
"code": 400,
"type": "InvalidKeyException",
"message": "Token Authorization failed. Token either incorrect, expired, or not provided correctly"
}
# 该错误表明 TorchServe 的令牌授权功能未正确配置,导致 API 请求被拒绝。
①找到修改 config.properties
文件
路径:
修改内容为:
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
number_of_netty_threads=32
job_queue_size=1000
vmargs=-Xmx4g -XX:+ExitOnOutOfMemoryError -XX:+HeapDumpOnOutOfMemoryError
prefer_direct_buffer=True
default_response_timeout=300
unregister_model_timeout=300
install_py_dep_per_model=true
# 添加
disable_token_authorization=true
②设置环境变:
export TS_DISABLE_TOKEN_AUTHORIZATION=true
③修改指令:
torchserve --start --ncs --model-store model_store --models densenet161.mar --disable-token-auth
问题解决!
5. 测试:
(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg
{
"tabby": 0.465851366519928,
"tiger_cat": 0.4652356207370758,
"Egyptian_cat": 0.06615941971540451,
"lynx": 0.0012939950684085488,
"plastic_bag": 0.00022918518516235054
}
(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$
3-农业小模型部署
1. 文件准备【水稻病害分类模型】
① model.py 模型文件
注:只存放模型网络结构!
② efficient_rice_disease_classification.pth 权重文件
注:权重保存方法 (避免结构对不上)
torch.save(model.state_dict(), "efficient_rice_disease_classification.pth")
③ index_to_name.json 分类索引文件
{
"0": "Bacteriablight",
"1": "Blast",
"2": "Brownspot",
"3": "Tungro"
}
④ rice_disease_classification_handler.py 负责定义模型的加载、预处理、推理和后处理等逻辑。
import json
from ts.torch_handler.image_classifier import ImageClassifier
class riceDiseaseClassificationHandler(ImageClassifier):
# 预处理输入数据
def postprocess(self, data):
results = []
for prediction in data:
probs = prediction.tolist()
max_prob = max(probs)
max_index = probs.index(max_prob)
label = self.mapping[str(max_index)] if self.mapping else str(max_index)
results.append({
"label": label,
"confidence": f"{max_prob * 10:.2f}%" # 百分比格式,保留两位小数
})
return [json.dumps(results[0], ensure_ascii=False)] # 返回 JSON 字符串
⑤ blast_rice_test.jpg 测试图片
2. 使用执行文件 注册模型
生成 rice_disease_classification.mar 文件。
# rice_disease_classification.sh
torch-model-archiver \
--model-name rice_disease_classification \
--version 1.0 \
--model-file /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/model.py \
--serialized-file /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/efficient_rice_disease_classification.pth \
--handler /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/rice_disease_classification_handler.py \
--extra-files /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/index_to_name.json \
--export-path /home74/liguangzhen/Project/Serve/serve/model_store \
-f
3. 启动服务
装 包: pip install timm
torchserve --start --ncs --model-store model_store --models rice_disease_classification.mar --disable-token-auth
4. 测试
(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ curl http://127.0.0.1:8080/predictions/rice_disease_classification -T /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/blast_rice_test.jpg
{
"label": "Blast", "confidence": "59.80%"
}
(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$
测试成功!
5.封装工具api.py
import requests
import subprocess
import time
import os
# 启动 TorchServe 服务并注册模型
def start_torchserve(model_store: str, model_name: str, model_file: str):
# 构建启动命令
cmd = [
"torchserve",
"--start",
"--ncs",
"--model-store", model_store,
"--models", f"{model_name}={model_file}",
"--disable-token-auth"
]
# 打开 os.devnull 以抑制输出
with open(os.devnull, 'w') as devnull:
try:
# 使用 subprocess.run 启动命令,并将 stdout 和 stderr 重定向到 devnull
subprocess.run(cmd, stdout=devnull, stderr=devnull, check=True)
print(f"TorchServe 已启动,并注册了模型 {model_name}")
except subprocess.CalledProcessError as e:
print(f"启动 TorchServe 时出错: {e}")
# 停止 TorchServe 服务
def stop_torchserve():
try:
subprocess.run(["torchserve", "--stop"], check=True)
print("TorchServe 已停止")
except subprocess.CalledProcessError as e:
print(f"停止 TorchServe 时出错: {e}")
def seed_classification(img_path: str) -> dict:
"""
种子分类函数,调用 TorchServe 提供的模型服务进行预测。
参数:
img_path (str): 本地图片的路径,支持 .jpg 或 .png 格式。
返回:
dict: 包含预测标签和置信度的字典。
"""
# 确保文件格式正确
if not (img_path.endswith('.jpg') or img_path.endswith('.png')):
raise ValueError("图片格式必须为 .jpg 或 .png")
# TorchServe 模型服务的 URL
url = "http://127.0.0.1:8080/predictions/seed_classification"
# 读取图片并发送 POST 请求
with open(img_path, 'rb') as img_file:
files = {'data': img_file}
response = requests.post(url, files=files)
# 检查响应状态
if response.status_code == 200:
return response.json()
else:
raise RuntimeError(f"请求失败,状态码: {response.status_code}, 信息: {response.text}")
def rice_disease_classification(img_path: str) -> dict:
"""
农作物病害分类函数,调用 TorchServe 提供的模型服务进行预测。
参数:
img_path (str): 本地图片的路径,支持 .jpg 或 .png 格式。
返回:
dict: 包含预测标签和置信度的字典。
"""
# 确保文件格式正确
if not (img_path.endswith('.jpg') or img_path.endswith('.png')):
raise ValueError("图片格式必须为 .jpg 或 .png")
# TorchServe 模型服务的 URL
url = "http://127.0.0.1:8080/predictions/rice_disease_classification"
# 读取图片并发送 POST 请求
with open(img_path, 'rb') as img_file:
files = {'data': img_file}
response = requests.post(url, files=files)
# 检查响应状态
if response.status_code == 200:
return response.json()
else:
raise RuntimeError(f"请求失败,状态码: {response.status_code}, 信息: {response.text}")
if __name__ == "__main__":
# 定义模型存储路径和模型名称
model_store = "/home74/liguangzhen/Project/Serve/serve/model_store"
model_name = "rice_disease_classification"
model_file = f"{model_name}.mar"
# 启动 TorchServe
start_torchserve(model_store, model_name, model_file)
# 等待一段时间,确保服务启动
time.sleep(5)
# 此处可以添加其他操作,例如发送推理请求
result = rice_disease_classification('/home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/blast_rice_test.jpg')
print(result)
# 停止 TorchServe
stop_torchserve()
测试成功!