Zero-shot text classification model

Published: 2025-03-17

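We use Facebook's bart-large-mnli model for classification. bart-large-mnli is a zero-shot text classification model based on Natural Language Inference (NLI): it can assign labels it was never explicitly trained on.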
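
To make the NLI framing concrete: as far as I understand the transformers zero-shot pipeline, each candidate label is turned into a hypothesis sentence (the default template is "This example is {}."), the input text serves as the premise, and the model's entailment logit is compared against its contradiction logit. The following is a minimal sketch of that scoring step in plain Python. The helper names and the logit values are made up for illustration, and no model is loaded.

```python
import math

# Default hypothesis template of the transformers zero-shot pipeline.
HYPOTHESIS_TEMPLATE = "This example is {}."


def build_hypotheses(candidate_labels, template=HYPOTHESIS_TEMPLATE):
    """Turn each candidate label into an NLI hypothesis sentence."""
    return [template.format(label) for label in candidate_labels]


def multi_label_score(contradiction_logit, entailment_logit):
    """Softmax over (contradiction, entailment); returns P(entailment)."""
    m = max(contradiction_logit, entailment_logit)  # subtract max for stability
    e_c = math.exp(contradiction_logit - m)
    e_e = math.exp(entailment_logit - m)
    return e_e / (e_c + e_e)


hypotheses = build_hypotheses(["travel", "cooking", "dancing"])
print(hypotheses)

# Hypothetical logits, for illustration only (not real model output).
print(multi_label_score(contradiction_logit=-2.0, entailment_logit=3.0))
```

Because each label is scored against its own hypothesis, the per-label scores in this mode are independent of one another.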

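First, install Hugging Face's transformers library (the pipeline also needs a deep learning backend such as PyTorch to be installed):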

pip install transformers

Hugging Face model page: https://huggingface.co/facebook/bart-large-mnli
The example code below determines which of three categories (travel, cooking, dancing) the sentence "one day I will see the world" belongs to:

from transformers import pipeline

zero_shot_classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli"
)

sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']

classification = zero_shot_classifier(
    sequence_to_classify,
    candidate_labels,
    multi_label=True      # allow more than one label to match
)
# print(classification)

print(classification['labels'])
print(classification['scores'])

Running it gives the following result. The travel label is the best match, with a score of 0.99:

['travel', 'dancing', 'cooking']
[0.994511067867279, 0.005706176161766052, 0.0018192899879068136]

Now try a different set of inputs:

sequence_to_classify = "The company's quarterly earnings increased by 20%, exceeding market expectations."
candidate_labels = ["finance", "sports", "politics", "technology"]

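The result is shown below. The finance label is the best match, with a score of 0.85. The technology label also scores 0.52, because with multi_label=True each label is scored independently and the scores need not sum to 1: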

['finance', 'technology', 'sports', 'politics']
[0.8526014089584351, 0.5210740566253662, 0.004615711513906717, 0.000885962916072458]
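
Since multi_label=True scores every label independently, more than one label can be a plausible match. A minimal sketch of picking the matching labels with a cutoff follows. The `labels_above` helper and the 0.5 threshold are assumptions for illustration, not part of the transformers API; the scores are copied from the output above.

```python
def labels_above(result, threshold=0.5):
    """Return (label, score) pairs whose score is at least `threshold`."""
    return [
        (label, score)
        for label, score in zip(result["labels"], result["scores"])
        if score >= threshold
    ]


# Output copied from the finance example above.
result = {
    "labels": ["finance", "technology", "sports", "politics"],
    "scores": [
        0.8526014089584351,
        0.5210740566253662,
        0.004615711513906717,
        0.000885962916072458,
    ],
}

matched = labels_above(result, threshold=0.5)
print([label for label, _ in matched])  # ['finance', 'technology']
```

The right threshold depends on the task. A stricter cutoff such as 0.8 would keep only finance here.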

This time, pass in a URL instead of a natural-language sentence and see how it is classified:

sequence_to_classify = "https://nba.sina.com.cn/"

The result is shown below. The sports label is the best match, with a score of 0.83:

['sports', 'technology', 'finance', 'politics']
[0.8341358304023743, 0.20186419785022736, 0.07376285642385483, 0.004406137391924858]
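
When trying several inputs like this, it can be convenient to classify them in one call. The sketch below assumes the pipeline accepts a list of strings and returns one dict per input with `sequence`, `labels`, and `scores` keys, sorted by descending score, which matches the outputs shown above. The `top_labels` helper is a hypothetical wrapper, not part of transformers.

```python
def top_labels(classifier, sequences, candidate_labels):
    """Classify several inputs and return {sequence: (best_label, best_score)}."""
    results = classifier(sequences, candidate_labels, multi_label=True)
    if isinstance(results, dict):  # a single string input yields a single dict
        results = [results]
    # Labels come back sorted by score, so index 0 is the best match.
    return {r["sequence"]: (r["labels"][0], r["scores"][0]) for r in results}


# Usage with the pipeline created earlier (downloads the model on first run):
# print(top_labels(
#     zero_shot_classifier,
#     ["https://nba.sina.com.cn/", "one day I will see the world"],
#     ["finance", "sports", "politics", "technology"],
# ))
```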
