搭建自己的AI模型应用网站:JavaScript + Flask-Python + ONNX

发布于:2024-06-16 ⋅ 阅读:(18) ⋅ 点赞:(0)

1. 前言

本文作者以一个前端新手视角,部署自己的神经网络模型作为后端,搭建自己的网站实现应用的实战经历。目前实现的网页应用有:

欢迎大家试用感受,本文将以博客基于GAN的序列号码预测中训练的pytorch模型为例,进行后端和前端搭建,并构建网站,以下是最终成果展示。
在这里插入图片描述
网址:http://www.funsound.cn:5002

2. 相关内容

相关知识点和工具语言如下,希望读者有一定的了解

  • 腾讯云服务器
  • Html + JavaScript 进行UI设计
  • pytorch 模型,onnx 模型导出
  • python flask 后端
  • 多进程服务实现并发访问

3. 后端工作

3.1 pytorch 模型转 onnx 模型

ONNX 模型是通用的NN格式,采用onnx格式将在服务器cpu推理上速度更快。

# 实例化生成器模型
generator = Generator(input_dim, output_dim)

# 加载训练好的生成器模型权重
generator.load_state_dict(torch.load('models/generator_model.pth'))
generator.eval()  # 设置生成器为评估模式

# 导出模型为 ONNX 格式
generator.export_onnx('models/generator_model.onnx', (batch_size, input_dim))

加载onnx模型进行推理

# 加载 ONNX 模型
ort_session = ort.InferenceSession('models/generator_model.onnx')
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
input_noise = np.random.randn(batch_size, input_dim).astype(np.float32)
generated_numbers = ort_session.run([output_name], {input_name: input_noise})[0]

基于onnx推理的CP号码生成算法封装成 【generator. LOTTO_GENERATOR】

3.2 多进程onnx服务

网站访问往往是一个多路并发访问场景,面对众多用户的请求,送入待处理,后端采用多进程进行调度。

if __name__ == "__main__":
    from generator import LOTTO_GENERATOR # 我们的gan网络生成算法

    # 初始化worker数量
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)

    # 获取并打印所有worker的状态
    res = get_workers_state(workers)
    print(res)

    # 提交100个任务
    worker_dir = "demo"
    for _ in range(100):
        task_id = generate_random_string(length=11)  # 生成长度为11的随机字符串作为task_id
        task_dir = f"{worker_dir}/{task_id}"  # 任务目录
        task_inp = generate_random_number_string(length=8)  # 生成长度为8的随机数字字符串作为任务输入
        task_prgs = f'{task_dir}/progress.txt'  # 任务进度文件路径
        task_rst = f'{task_dir}/result.txt'  # 任务结果文件路径
        
        os.system(f'mkdir -p {task_dir}')  # 创建任务目录
        params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst
        }
        submit_task(workers=workers, params=params)  # 提交任务
        time.sleep(0.01)  # 等待10毫秒后提交下一个任务

注意代码中多进程服务处理用户请求采用异步方式,用户提交任务后获取task_id, 主进程不会阻塞, 用户根据task_id来追踪自己的任务进度(task_prgs)和结果(task_rst)。

其中调度方式根据子进程的忙碌情况决定,选取最闲的子进程处理用户请求

def submit_task(workers, params: dict):
    # 找到任务最少的worker
    min_task_worker = min(workers, key=lambda worker: worker.queue.qsize() + worker.working.value)
    min_task_worker.queue.put(params)  # 将任务提交到最少任务的worker队列中
    print(f'assign the task to worker-{min_task_worker.wid}'

3.3 基于Flask搭建http访问接口

我们的后端代码如下,例如我们的ip 是 100.100.123,端口试用5002,则构建了以下http访问接口:
http一般格式: 【http://IP地址:端口/路由】

  • http://100.100.123:5002/ 主页
  • http://100.100.123:5002/lotto 提交任务 【输入:用户幸运数字,输出:task_id】
  • http://100.100.123:5002/get_worker_state 子进程负载状态 【输入:task_id,输出:负载状态】
  • http://100.100.123:5002/get_task_prgs 任务完成进度 【输入:task_id,输出:任务进度】
  • http://100.100.123:5002/get_task_rst 任务结果 【输入:task_id,输出:任务结果】
from flask import Flask, jsonify,render_template,request
from generator import LOTTO_GENERATOR
from workers import *
import datetime
import json 

def get_now_time():
    current_time = datetime.datetime.now()
    return current_time.strftime('%Y-%m-%d %H:%M:%S')

def task_log(text,log_file="TASK.LOG"):
    with open(log_file,'a+') as f:
        print(text,file=f)


app = Flask(__name__)
USER_DIR = "user_data"
TASK_MAP = {}

"""
主页
"""
@app.route('/')
def index():
    return render_template('index.html')


@app.route('/lotto', methods=['POST'])
def predict():
    # 获取客户端信息
    ip = request.remote_addr
    data = request.get_json()

    task_id = ip + "_" + generate_random_string(20)
    user_id = ip
    task_inp = data['luck_num'] # 8位数字字符串 或者 空字符串
    task_dir = "%s/%s/%s" % (USER_DIR, user_id, task_id)
    task_prgs = f'{task_dir}/progress.txt'  # 任务进度文件路径
    task_rst = f'{task_dir}/result.txt'  # 任务结果文件路径
    task_log(f"TIME:{get_now_time()}")
    task_log(f"TASK_ID:{task_id}")
    task_log("")

    # 生成临时文件
    if not os.path.exists(task_dir): os.makedirs(task_dir)
    with open(task_prgs,'wt') as f:
        json.dump([0.0,'start'],f,indent=4)
    TASK_MAP[task_id] = {'task_dir': task_dir,
                         'task_prgs': task_prgs,
                         'task_rst': task_rst, }
    
    # 提交任务
    params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst
        }
    submit_task(workers=workers, params=params)  # 提交任务
    return task_id


"""
获得引擎状态
"""
@app.route('/get_worker_state', methods=['GET'])
def get_worker_state():
    ip = request.remote_addr
    res = {}
    for worker in workers:
        res[worker.wid] = worker.queue.qsize() + worker.working.value
    return res


"""
获得任务进度
"""
@app.route('/get_task_prgs', methods=['POST'])
def get_task_prgs():
    ip = request.remote_addr
    data = request.get_json()
    task_id = data['task_id']
    if task_id not in TASK_MAP:
        return [-1, 'no such task id']
    else:
        task_prgs = TASK_MAP[task_id]['task_prgs']
        with open(task_prgs, 'rt') as f:
            content = json.load(f)
        return content

"""
获得任务结果
"""
@app.route('/get_task_rst', methods=['POST'])
def get_task_rst():
    ip = request.remote_addr
    data = request.get_json()
    task_id = data['task_id']
    if task_id not in TASK_MAP:
        return {}
    else:
        task_rst = TASK_MAP[task_id]['task_rst']
        with open(task_rst, 'rt') as f:
            content = json.load(f)
        return content

if __name__ == '__main__':

    # 初始化worker数量
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)
    
    app.run(host='0.0.0.0', port=5002)

这样后端就搭建起来啦,这里有4个onnx 模型在后台监听

3.4 python客户端测试

import requests
import time
import json

# 定义服务端地址
server_url = 'http://110.110.123:5002' # 你的服务器和端口
headers = {'Content-Type': 'application/json'}

# 检查服务器 Worker 状态
def check_worker_status():
    response = requests.get(f'{server_url}/get_worker_state')
    if response.status_code == 200:
        worker_status = response.json()
        idle_workers = [wid for wid, status in worker_status.items() if status == 0]
        if idle_workers:
            print("Idle workers available:", idle_workers)
            return True
        else:
            print("No idle workers available.")
            return False
    else:
        print("Failed to get worker status.")
        return False

# 提交任务
def submit_task(json_data):
    if not check_worker_status():
        print("No idle workers available. Task submission failed.")
        return None

    response = requests.post(f'{server_url}/lotto', json=json_data)
    if response.status_code == 200:
        task_id = response.text
        print(f"Task submitted successfully. Task ID: {task_id}")
        return task_id
    else:
        print("Failed to submit task.")
        return None

# 轮询任务进度
def poll_task_progress(task_id):
    while True:
        json_data = {'task_id':task_id}
        response = requests.post(f'{server_url}/get_task_prgs', json=json_data)
        if response.status_code == 200:
            progress, text = response.json()
            print(f"Progress: {progress}, Status: {text}")
            if progress == 1:
                print("Task completed successfully.")
                return True
            elif progress == -1:
                print(f"Task failed: {text}")
                return False
        else:
            print("Failed to get task progress.")
            return False
        time.sleep(3)  # 每3秒查询一次

# 获取任务结果
def get_task_result(task_id):
    json_data = {'task_id':task_id}
    response = requests.post(f'{server_url}/get_task_rst', json=json_data)
    if response.status_code == 200:
        result = response.json()
        print("Task result:", result)
        return result
    else:
        print("Failed to get task result.")
        return None


# 主函数
def main():
    json_data = {'luck_num':""}
    # json_data = {'luck_num':"12345678"}

    # 提交TTS任务
    task_id = submit_task(json_data)
    if not task_id:
        return
        
    # 轮询任务进度
    if poll_task_progress(task_id):
        # 获取任务结果
        result = get_task_result(task_id)

if __name__ == "__main__":
    main()

在这里插入图片描述

访问成功

4. 前端工作

4.1 JavaScript 访问 http 函数

JavaScript 调用 http端口如下:

<script>

        /* 提交任务 */
        function submitTask() {
            var button = document.querySelector("button");
            button.disabled = true;
            button.innerText = "正在生成...";

            var useLuckyNumber = document.getElementById("use_lucky_number").checked;
            var luckInput = document.getElementById("luck_input");
            var luckNum = useLuckyNumber ? luckInput.value : "";
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/lotto", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var taskId = xhr.responseText;
                    checkProgress(taskId);
                } else if (xhr.readyState == 4) {
                    button.disabled = false;
                    button.innerText = "生成";
                    alert("任务提交失败,请重试。");
                }
            };
            xhr.send(JSON.stringify({luck_num: luckNum}));
        }

        /* 检查任务进度 */
        function checkProgress(taskId) {
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/get_task_prgs", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var response = JSON.parse(xhr.responseText);
                    var progress = response[0];
                    var status = response[1];
                    // document.getElementById("progress").innerText = "进度: " + progress + ", 状态: " + status;
                    if (progress == 1) {
                        getResult(taskId);
                    } else if (progress == -1) {
                        var button = document.querySelector("button");
                        button.disabled = false;
                        button.innerText = "生成";
                        alert("任务失败: " + status);
                    } else {
                        setTimeout(function() { checkProgress(taskId); }, 3000);
                    }
                }
            };
            xhr.send(JSON.stringify({task_id: taskId}));
        }

        /* 获取任务结果 */
        function getResult(taskId) {
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/get_task_rst", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var response = JSON.parse(xhr.responseText);
                    displayResult(response);
                    var button = document.querySelector("button");
                    button.disabled = false;
                    button.innerText = "生成";
                }
            };
            xhr.send(JSON.stringify({task_id: taskId}));
        }

        /* 显示任务结果 */
        function displayResult(response) {
            var frontNumbers = response.front_numbers;
            var backNumbers = response.back_numbers;
            var resultContainer = document.getElementById("result");
            resultContainer.innerHTML = ""; // 清空之前的结果

            for (var i = 0; i < frontNumbers.length; i++) {
                var lotterySet = document.createElement("div");
                lotterySet.className = "lottery-set";
                
                frontNumbers[i].forEach(function(number) {
                    var numberBall = document.createElement("div");
                    numberBall.className = "number-ball front-ball";
                    numberBall.innerText = number;
                    lotterySet.appendChild(numberBall);
                });

                backNumbers[i].forEach(function(number) {
                    var numberBall = document.createElement("div");
                    numberBall.className = "number-ball back-ball";
                    numberBall.innerText = number;
                    lotterySet.appendChild(numberBall);
                });

                resultContainer.appendChild(lotterySet);
            }
        }
    </script>

4.2 制作网页index.html

注意到Flask提供了网页渲染功能,这样我们可以设计我们的主页

@app.route('/')
def index():
    return render_template('index.html')

把上述JS脚本放入index.html 就可以访问后端服务啦,具体html的UI显示,由于代码量很大这里不与展示了,感兴趣同学可以根据上述python客户端的访问逻辑试用GPT为你编写index.html,手机端访问效果如下:

在这里插入图片描述

5. 最后

上述是个人搭建自己网站部署AI应用的简单过程,完整源码后期整理上传,欢迎大家留言关注~