用Transformers和FastAPI快速搭建后端算法api

发布于:2024-12-07 ⋅ 阅读:(33) ⋅ 点赞:(0)

用Transformers和FastAPI快速搭建后端算法api

如果你对自然语言处理 (NLP, Natural Language Processing) 感兴趣,想要了解ChatGPT是怎么来的,想要搭建自己的聊天机器人,想要做分类、翻译、摘要等各种NLP任务,HuggingfaceTransformers库是你不能错过的选择。在Transformers的帮助下,许多任务只需几行代码便能轻松解决。

预训练模型

在Transformer结构中,存在许多的 W W W权重矩阵,训练后具有大量的信息,并且可以保存为文件。如果有大佬训练了效果特别好的模型,他能不能慷慨分享出来,给我们用?幸运的是,真的存在这样一群具有无私开源精神的人,为推动科技发展造福全人类而不懈努力,这便是Huggingface社区的贡献者们。

让我们进入Huggingface的模型库>点击直达一探究竟。

印入眼帘的便是大量的预训练模型,包括著名的BERTGPT2等,常可以看到以下字眼:

  • large/base/small/tiny/mini:指模型大小;
  • cased/uncased:cased指区分大小写,”Hi”和”hi”是两个不同的词;uncased指预先对文本做lower-case,”Hi”和”hi”都会转成”hi”,于是词表只有小写的单词;
  • zh/fr/en:指语言,中文/法语/英语等;
  • wwm:Whole Word Masking,对全词进行Mask。

在模型首页有三个选项卡Model Card Files Community,作用如下:

  • Model Card:介绍这个模型,对应Files里的README.md。一般都会提供How to use,能用Transformers的几行代码把此模型跑起来。有些模型的右侧会提供Hosted inference API,直接在线尝试使用这个模型;底下还有Spaces,指使用了这些模型的空间,设计了界面,不过俩蛮多时候都不好使。
  • Files:保存了模型的各个文件。点击中间的下载图标即可下载,我试了右键另存为不太行。用PyTorch必要下载:
    1. pytorch_model.bin,模型的参数文件,最大(其他大文件基本都是针对其他类深度学习框架,如TensorFlow的tf_model.h5,不用下载);
    2. config.json,保存模型结构、参数等信息;
    3. vocab.txt(.json),词表文件;
    4. 其他小文件不是每个模型都有,但也请一并下载,可能包括特殊分词和模型结构的.py文件,不可或缺。
  • Community:提供模型的讨论区,结合了GitHub的issue和PR,对模型的问题或改进可以提在这。

开箱即用的pipeline

在Model Card中常见

from transformers import pipeline

使用pipeline,首先需要安装transformers,在pytorch环境配置好的情况下,只需:

pip install transformers

pipeline封装了预训练模型和对应的前处理和后处理环节,可以直接进行包括但不限于以下任务:

任务 描述 模态 Pipeline使用方式
Text classification 给一段文本赋上标签 NLP pipeline(task=“sentiment-analysis”)
Text generation 生成文本 NLP pipeline(task=“text-generation”)
Summarization 生成一段文本或文档的摘要 NLP pipeline(task=“summarization”)
Image classification 给图片赋上标签 Computer vision pipeline(task=“image-classification”)
Image segmentation 给图片的每个像素赋上标签 Computer vision pipeline(task=“image-segmentation”)
Object detection 预测图像中的对象边界和类别 Computer vision pipeline(task=“object-detection”)
Audio classification 给音频赋上标签 Audio pipeline(task=“audio-classification”)
Automatic speech recognition 语音转文本 Audio pipeline(task=“automatic-speech-recognition”)
Visual question answering 给定图片和问题,给出回答 Multimodal pipeline(task=“vqa”)
Document question answering 给定带有文档(例如表格)的图片和问题,给出回答 Multimodal pipeline(task=“document-question-answering”)
Image captioning 给图片起个标题 Multimodal pipeline(task=“image-to-text”)

Pipeline使用案例

例如情感分析,我们只需要输入文本,就可以得到其情感标签(积极/消极)以及对应的概率:

from transformers import pipeline

classifier = pipeline("sentiment-analysis")
result = classifier("I've been waiting for a HuggingFace course my whole life.")
print(result)

pipeline会自动完成以下三个步骤:

  1. 将文本预处理为模型的输入格式;
  2. 将预处理好的文本送入模型;
  3. 对模型的预测值进行后处理,输出想要的结果。

pipeline会自动选择合适的预训练模型来完成任务。例如对于情感分析,默认就会选择微调好的英文情感模型 distilbert-base-uncased-finetuned-sst-2-english

要想指定模型,可以加入model参数,如

generator = pipeline("text-generation", model="distilgpt2")

Transformers 库在首次加载时某模型时,会自动缓存模型到~/.cache/huggingface,再次使用时,直接调用缓存好的模型。

如果你也不喜欢这种被人瞒着的感觉,最好下载模型,在本地调用。比如下载好了bert-base-chinese到当前目录,再运行:

classifier = pipeline('text-classification', model='./bert-base-chinese')

这时会报一个warning,说权重没有初始化什么的,这是因为bert本身没有对文本分类任务做训练,所以没有关于分类的权重,需要自己训练,这个会在后面的章节进行学习。

FastAPI搭建后端算法接口

FastAPI是最快的基于 Python 的 Web 框架之一。相比于Flask,有以下特点:

  • FastAPI使用 Pydantic 进行数据验证;
  • Uvicorn提供异步请求能力。

先安装:

pip install fastapi uvicorn

以翻译任务为例,首先新建一个api.py文件,用transformers的pipeline搭建模型:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

TRANS_MODEL_PATH = 'liam168/trans-opus-mt-en-zh'
def trans(text):
    model = AutoModelForSeq2SeqLM.from_pretrained(TRANS_MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(TRANS_MODEL_PATH)
    translation = pipeline("translation_en_to_zh", model=model, tokenizer=tokenizer)
    return translation(text, max_length=400)

用pydantic定义数据格式,翻译任务就只需要接收模型名称输入文本

from pydantic import BaseModel

class Data(BaseModel):
    model_name: str
    input: str

FastAPI构建服务:

from fastapi import FastAPI

app = FastAPI()
@app.post('/api/trans')
def translate(data_zh: Data):
    res = {
        "code": 200,
        "message": "OK"
    }
    msg = data_zh.input
    if data_zh.model_name == 'opus':
        trans_res = trans(msg)
        res.update({data_zh.model_name: trans_res[0]['translation_text']})
    return res

网站公告

今日签到

点亮在社区的每一天
去签到