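We use Facebook's bart-large-mnli model for the classification. bart-large-mnli is the BART-large model fine-tuned on the MultiNLI dataset. Through natural language inference (NLI), it can serve as a zero-shot text classifier, meaning it can assign labels it was never explicitly trained on.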
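To make "zero-shot via NLI" concrete: the classifier treats the input text as the premise and turns each candidate label into a hypothesis sentence. The model then judges whether the premise entails each hypothesis. The sketch below only builds those premise/hypothesis pairs in plain Python. `build_nli_pairs` is a hypothetical helper; the template string matches the default `hypothesis_template` of the transformers zero-shot pipeline, as far as I know.

```python
# Hypothetical helper illustrating how an NLI-based zero-shot classifier
# pairs the input text (premise) with one hypothesis per candidate label.
DEFAULT_TEMPLATE = "This example is {}."  # default hypothesis_template in the transformers pipeline


def build_nli_pairs(sequence, candidate_labels, template=DEFAULT_TEMPLATE):
    """Return one (premise, hypothesis) pair per candidate label."""
    return [(sequence, template.format(label)) for label in candidate_labels]


pairs = build_nli_pairs("one day I will see the world", ["travel", "cooking", "dancing"])
for premise, hypothesis in pairs:
    # e.g. "one day I will see the world -> This example is travel."
    print(premise, "->", hypothesis)
```

A label whose hypothesis is strongly entailed by the premise gets a high score.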
First, install Hugging Face's transformers library. The pipeline also needs a deep-learning backend such as PyTorch:
pip install transformers torch
Model page on Hugging Face: https://huggingface.co/facebook/bart-large-mnli
The following example determines which of three categories (travel, cooking, dancing) the sentence "one day I will see the world" belongs to:
from transformers import pipeline

# Load the zero-shot classification pipeline backed by bart-large-mnli
zero_shot_classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli"
)

sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']

classification = zero_shot_classifier(
    sequence_to_classify,
    candidate_labels,
    multi_label=True  # score each label independently, so several labels can match
)
# print(classification)
print(classification['labels'])  # labels sorted from best to worst match
print(classification['scores'])  # score for each label, in the same order
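The `multi_label` flag changes how scores are computed. As I understand the transformers implementation, the MNLI model produces contradiction/neutral/entailment logits for each label's hypothesis. With `multi_label=True`, each label gets its own softmax over (contradiction, entailment), so the scores are independent and need not sum to 1. With `multi_label=False`, a single softmax is taken over the entailment logits of all labels, so the scores sum to 1. The logits below are made up purely for illustration; they are not real model output.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def multi_label_scores(logits):
    """logits: one (contradiction, entailment) pair per label (neutral dropped).
    Each label gets its own softmax, so the scores are independent."""
    return [softmax([c, e])[1] for c, e in logits]


def single_label_scores(logits):
    """Softmax across the entailment logits of all labels, so scores sum to 1."""
    return softmax([e for _, e in logits])


# Hypothetical (contradiction, entailment) logits for three labels, invented for this sketch.
fake_logits = [(-2.0, 3.0), (1.5, -1.0), (0.5, 0.2)]
print(multi_label_scores(fake_logits))
print(single_label_scores(fake_logits))
```

This explains why the second example in this article can report two labels with fairly high scores at the same time.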
Running it produces the following output; travel is the best match, with a score of 0.99:
['travel', 'dancing', 'cooking']
[0.994511067867279, 0.005706176161766052, 0.0018192899879068136]
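The result dict holds two aligned lists, so the best label can be read off by pairing them. The sketch below reuses the exact output printed above instead of calling the model. Since the pipeline already sorts labels by descending score, taking the first pair would also work; `max` just makes the intent explicit.

```python
# Output copied from the run above
labels = ['travel', 'dancing', 'cooking']
scores = [0.994511067867279, 0.005706176161766052, 0.0018192899879068136]

# Pair each label with its score and pick the highest-scoring one
ranked = list(zip(labels, scores))
best_label, best_score = max(ranked, key=lambda pair: pair[1])
print(f"{best_label}: {best_score:.2f}")  # travel: 0.99
```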
Now try a different input and set of labels:
sequence_to_classify = "The company's quarterly earnings increased by 20%, exceeding market expectations."
candidate_labels = ["finance", "sports", "politics", "technology"]
The result shows finance as the best match, with a score of 0.85:
['finance', 'technology', 'sports', 'politics']
[0.8526014089584351, 0.5210740566253662, 0.004615711513906717, 0.000885962916072458]
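Because `multi_label=True` scores every label independently, these scores sum to more than 1, and technology (0.52) is also above one half. When several labels may apply, one common approach is to keep every label whose score clears a threshold. The 0.5 cut-off below is an assumption for this sketch, not something the model prescribes; the scores are the ones printed above.

```python
# Output copied from the run above
labels = ['finance', 'technology', 'sports', 'politics']
scores = [0.8526014089584351, 0.5210740566253662, 0.004615711513906717, 0.000885962916072458]

THRESHOLD = 0.5  # assumed cut-off; tune it on your own data

# Keep every label whose independent score reaches the threshold
selected = [label for label, score in zip(labels, scores) if score >= THRESHOLD]
print(selected)    # ['finance', 'technology']
print(sum(scores))  # greater than 1, because each label was scored on its own
```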
This time, let's pass a URL instead of natural-language text and see how it is classified:
sequence_to_classify = "https://nba.sina.com.cn/"
With the same four candidate labels, sports is the best match, with a score of 0.83:
['sports', 'technology', 'finance', 'politics']
[0.8341358304023743, 0.20186419785022736, 0.07376285642385483, 0.004406137391924858]
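Note that the pipeline only sees the literal string it is given; it does not download the web page. The sports score therefore most likely comes from tokens such as "nba" in the host name. If you want to experiment, one hypothetical option is to split the URL into plain words before classifying it. `url_to_text` is an illustrative helper, not part of transformers, and I make no claim about how it changes the scores.

```python
from urllib.parse import urlparse


def url_to_text(url):
    """Hypothetical preprocessing: turn a URL's host and path into space-separated words.
    The classifier only receives this string; no page content is fetched."""
    parsed = urlparse(url)
    host_parts = (parsed.hostname or "").split(".")
    path_parts = [part for part in parsed.path.split("/") if part]
    return " ".join(host_parts + path_parts)


print(url_to_text("https://nba.sina.com.cn/"))  # nba sina com cn
```

The resulting text can be passed to `zero_shot_classifier` in place of the raw URL, with the same `candidate_labels`.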