【torchserve】农业小模型部署

发布于:2025-04-09 ⋅ 阅读:(40) ⋅ 点赞:(0)

目录

TorchServe架构

TorchScript 格式部署

torch.compile 函数

1-环境配置

 2-示例

image_classifier示例(densenet161)

3-农业小模型部署

1. 文件准备【水稻病害分类模型】

2. 使用执行文件  注册模型

3. 启动服务

4. 测试

5.封装工具api.py


TorchServe架构

TorchScript 格式部署

在使用 TorchServe 部署 PyTorch 模型时,您可以选择直接使用 Python 模型(即“普通注册”)或将模型转换为 TorchScript 格式。这两种方法的主要区别包括:​

  1. 模型格式与序列化

  2. 部署环境

    • 普通注册:​需要在部署环境中安装完整的 Python 解释器和相关依赖库。​

    • TorchScript 方法:​转换后的模型可以在不依赖 Python 的环境中运行,适用于 C++ 等其他语言的运行时环境。 ​Stack Overflow+1Hugging Face+1

torch.compile 函数

用于通过即时(JIT)编译优化模型或函数的运行效率。​该函数接受多个参数,允许用户根据需求调整编译行为。以下是主要参数及其作用:​CSDN+2CSDN+2听歌语+2掘金

  • model:​(可选)需要优化的模型或函数,通常是一个 torch.nn.Module 实例或可调用的 Python 函数。​

  • fullgraph:​(布尔值)指示是否要求编译整个计算图。如果设置为 True,则强制编译完整的计算图;如果为 False,则允许在遇到无法编译的代码时将计算图拆分为子图。默认值为 False。 ​掘金

  • dynamic:​(布尔值)指示是否支持动态形状。如果设置为 True,编译的模型可以处理输入张量的动态变化,但可能会影响性能。默认值为 False。 ​掘金

  • backend:​(字符串或可调用对象)指定用于编译的后端。默认值为 "inductor",这是 PyTorch 内置的编译器,能够将计算图转换为优化的机器代码。用户可以根据需求选择其他后端,如 "eager"(不进行额外优化)等。 ​CSDN+2掘金+2CSDN+2

  • mode:​(字符串)控制编译模式,影响编译时间和运行时性能的权衡。可选值包括:​

    • "default":​适用于大多数情况,编译速度和运行时性能之间取得平衡。​

    • "reduce-overhead":​适用于小型模型,可能增加编译时间,但减少运行时开销。​

    • "max-autotune":​编译时间最长,但可能提供最佳的运行时性能。 ​CSDN+1Hugging Face+1

  • options:​(字典)向后端传递的额外选项,允许用户进一步自定义编译行为。​

  • disable:​(布尔值)如果设置为 True,则禁用 torch.compile 的功能,使其成为一个空操作,主要用于测试目的。默认值为 False


1-环境配置

(base) liguangzhen@ubuntu:~$ java -version
java version "17.0.12" 2024-07-16 LTS
Java(TM) SE Runtime Environment (build 17.0.12+8-LTS-286)
Java HotSpot(TM) 64-Bit Server VM (build 17.0.12+8-LTS-286, mixed mode, sharing)

参考: 入门指南

conda create -n  torchserve  python=3.8 -y

conda activate torchserve

python ./ts_scripts/install_dependencies.py --cuda=cu121

conda install torchserve torch-model-archiver torch-workflow-archiver -c pytorch

pip install nvgpu

git clone https://github.com/pytorch/serve.git

 2-示例

torchServe==0.12

image_classifier示例(densenet161)

1. 创建目录存储模型

2. 进入文件夹创建一个目录来存储您的模型,下载权重。

wget https://download.pytorch.org/models/densenet161-8d451a50.pth

3. 创建执行文件或者使用指令

# densenet161.sh
torch-model-archiver --model-name densenet161 --version 1.0 \
    --model-file /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/densenet_161/model.py \
    --serialized-file /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/densenet_161/densenet161-8d451a50.pth \
    --export-path /home74/liguangzhen/Project/Serve/serve/model_store \
    --extra-files /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/index_to_name.json \
    --handler image_classifier \
    -f


# torchserve --start --ncs --model-store model_store --models densenet161.mar

# curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg

4. 启动serve

# 启动
torchserve --start --ncs --model-store model_store --models densenet161.mar

# 测试
curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg

# 停止
torchserve --stop

执行路径:(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ torchserve --start --ncs --model-store model_store --models densenet161.mar

问题:

(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg
{
  "code": 400,
  "type": "InvalidKeyException",
  "message": "Token Authorization failed. Token either incorrect, expired, or not provided correctly"
}

# 该错误表明 TorchServe 的令牌授权功能未正确配置,导致 API 请求被拒绝。

①找到修改 config.properties 文件

        路径:

        修改内容为: 

inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081

number_of_netty_threads=32
job_queue_size=1000

vmargs=-Xmx4g -XX:+ExitOnOutOfMemoryError -XX:+HeapDumpOnOutOfMemoryError
prefer_direct_buffer=True

default_response_timeout=300
unregister_model_timeout=300
install_py_dep_per_model=true

# 添加
disable_token_authorization=true

②设置环境变:

export TS_DISABLE_TOKEN_AUTHORIZATION=true

③修改指令:

torchserve --start --ncs --model-store model_store --models densenet161.mar --disable-token-auth

问题解决! 

5. 测试:

(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ curl http://127.0.0.1:8080/predictions/densenet161 -T /home74/liguangzhen/Project/Serve/serve/examples/image_classifier/kitten.jpg
{
  "tabby": 0.465851366519928,
  "tiger_cat": 0.4652356207370758,
  "Egyptian_cat": 0.06615941971540451,
  "lynx": 0.0012939950684085488,
  "plastic_bag": 0.00022918518516235054
}
(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ 

3-农业小模型部署

1. 文件准备【水稻病害分类模型】

①  model.py 模型文件

注:只存放模型网络结构!

②  efficient_rice_disease_classification.pth 权重文件

注:权重保存方法  (避免结构对不上)

torch.save(model.state_dict(), "efficient_rice_disease_classification.pth")

③  index_to_name.json  分类索引文件

{
    "0": "Bacteriablight",
    "1": "Blast",
    "2": "Brownspot",
    "3": "Tungro"
}

④ rice_disease_classification_handler.py  负责定义模型的加载、预处理、推理和后处理等逻辑。

import json
from ts.torch_handler.image_classifier import ImageClassifier

class riceDiseaseClassificationHandler(ImageClassifier):
    # 预处理输入数据
    def postprocess(self, data):
        results = []
        for prediction in data:
            probs = prediction.tolist()
            max_prob = max(probs)
            max_index = probs.index(max_prob)
            label = self.mapping[str(max_index)] if self.mapping else str(max_index)
            results.append({
                "label": label,
                "confidence": f"{max_prob * 10:.2f}%"  # 百分比格式,保留两位小数
            })
        return [json.dumps(results[0], ensure_ascii=False)]  # 返回 JSON 字符串

⑤  blast_rice_test.jpg 测试图片

2. 使用执行文件  注册模型

生成  rice_disease_classification.mar  文件。

# rice_disease_classification.sh

torch-model-archiver \
    --model-name rice_disease_classification \
    --version 1.0 \
    --model-file /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/model.py \
    --serialized-file /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/efficient_rice_disease_classification.pth \
    --handler /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/rice_disease_classification_handler.py \
    --extra-files /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/index_to_name.json \
    --export-path /home74/liguangzhen/Project/Serve/serve/model_store \
    -f

3. 启动服务

装 包: pip install timm

torchserve --start --ncs --model-store model_store --models rice_disease_classification.mar --disable-token-auth

4. 测试

(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ curl http://127.0.0.1:8080/predictions/rice_disease_classification -T /home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/blast_rice_test.jpg
{
"label": "Blast", "confidence": "59.80%"
}
(torchserve) liguangzhen@ubuntu:~/Project/Serve/serve$ 

测试成功!

5.封装工具api.py

import requests
import subprocess
import time
import os



# 启动 TorchServe 服务并注册模型
def start_torchserve(model_store: str, model_name: str, model_file: str):

    # 构建启动命令
    cmd = [
        "torchserve",
        "--start",
        "--ncs",
        "--model-store", model_store,
        "--models", f"{model_name}={model_file}",
        "--disable-token-auth"
    ]

    # 打开 os.devnull 以抑制输出
    with open(os.devnull, 'w') as devnull:
        try:
            # 使用 subprocess.run 启动命令,并将 stdout 和 stderr 重定向到 devnull
            subprocess.run(cmd, stdout=devnull, stderr=devnull, check=True)
            print(f"TorchServe 已启动,并注册了模型 {model_name}")
        except subprocess.CalledProcessError as e:
            print(f"启动 TorchServe 时出错: {e}")

# 停止 TorchServe 服务
def stop_torchserve():
    try:
        subprocess.run(["torchserve", "--stop"], check=True)
        print("TorchServe 已停止")
    except subprocess.CalledProcessError as e:
        print(f"停止 TorchServe 时出错: {e}")


def seed_classification(img_path: str) -> dict:
    """
    种子分类函数,调用 TorchServe 提供的模型服务进行预测。

    参数:
        img_path (str): 本地图片的路径,支持 .jpg 或 .png 格式。

    返回:
        dict: 包含预测标签和置信度的字典。
    """
    # 确保文件格式正确
    if not (img_path.endswith('.jpg') or img_path.endswith('.png')):
        raise ValueError("图片格式必须为 .jpg 或 .png")

    # TorchServe 模型服务的 URL
    url = "http://127.0.0.1:8080/predictions/seed_classification"

    # 读取图片并发送 POST 请求
    with open(img_path, 'rb') as img_file:
        files = {'data': img_file}
        response = requests.post(url, files=files)

    # 检查响应状态
    if response.status_code == 200:
        return response.json()
    else:
        raise RuntimeError(f"请求失败,状态码: {response.status_code}, 信息: {response.text}")

def rice_disease_classification(img_path: str) -> dict:
    """
    农作物病害分类函数,调用 TorchServe 提供的模型服务进行预测。

    参数:
        img_path (str): 本地图片的路径,支持 .jpg 或 .png 格式。

    返回:
        dict: 包含预测标签和置信度的字典。
    """
    # 确保文件格式正确
    if not (img_path.endswith('.jpg') or img_path.endswith('.png')):
        raise ValueError("图片格式必须为 .jpg 或 .png")

    # TorchServe 模型服务的 URL
    url = "http://127.0.0.1:8080/predictions/rice_disease_classification"

    # 读取图片并发送 POST 请求
    with open(img_path, 'rb') as img_file:
        files = {'data': img_file}
        response = requests.post(url, files=files)

    # 检查响应状态
    if response.status_code == 200:
        return response.json()
    else:
        raise RuntimeError(f"请求失败,状态码: {response.status_code}, 信息: {response.text}")


if __name__ == "__main__":
    # 定义模型存储路径和模型名称
    model_store = "/home74/liguangzhen/Project/Serve/serve/model_store"
    model_name = "rice_disease_classification"
    model_file = f"{model_name}.mar"
    # 启动 TorchServe
    start_torchserve(model_store, model_name, model_file)
    # 等待一段时间,确保服务启动
    time.sleep(5)
    # 此处可以添加其他操作,例如发送推理请求
    result = rice_disease_classification('/home74/liguangzhen/Project/Serve/serve/task/rice_disease_classification/blast_rice_test.jpg')
    print(result)
    # 停止 TorchServe
    stop_torchserve()
    

测试成功!