去做具体的事,然后稳稳托举自己
—— 25.3.17
数据文件:
通过网盘分享的文件:Ner命名实体识别任务
链接: https://pan.baidu.com/s/1fUiin2um4PCS5i91V9dJFA?pwd=yc6u 提取码: yc6u
--来自百度网盘超级会员v3的分享
一、配置文件 config.py
1.模型与数据路径
model_path:模型训练完成后保存的位置。例如:保存最终的模型权重文件。
schema_path:数据结构定义文件,通常用于描述数据的格式(如字段名、标签类型)。
在NER任务中,可能定义实体类别(如 {"PERSON": "人名", "ORG": "组织"}
)。
train_data_path:训练数据集路径,通常为标注好的文本文件(如 train.txt
或 JSON
格式)。
valid_data_path: 验证数据集路径,用于模型训练时的性能评估和超参数调优。
vocab_path:字符词汇表文件,记录模型中使用的字符集(如中文字符、字母、数字等)。
2.模型架构
max_length:输入文本的最大序列长度。超过此长度的文本会被截断或填充(如用 [PAD]
)。
hidden_size:模型隐藏层神经元的数量,影响模型容量和计算复杂度。
num_layers:模型的堆叠层数(如LSTM、Transformer的编码器/解码器层数)。
class_num:任务类别总数。例如:NER任务中可能有9种实体类型。
vocab_size:词表大小
3.训练配置
epoch:训练轮数。每轮遍历整个训练数据集一次。
batch_size:每次梯度更新所使用的样本数量。较小的批次可能更适合内存受限的环境。
optimizer:优化器类型,用于调整模型参数。Adam是常用优化器,结合动量梯度下降。
learning_rate:学习率,控制参数更新的步长。值过小可能导致训练缓慢,过大易过拟合。
use_crf:是否启用条件随机场(CRF)层。在序列标注任务(如NER)中,CRF可捕捉标签间的依赖关系,提升准确性。
4.预训练模型
bert_path:预训练BERT模型的路径。BERT是一种强大的预训练语言模型,此处可能用于微调或特征提取。
# -*- coding: utf-8 -*-
"""
配置参数信息
"""
Config = {
"model_path": "model_output",
"schema_path": "ner_data/schema.json",
"train_data_path": "ner_data/train",
"valid_data_path": "ner_data/test",
"vocab_path":"chars.txt",
"max_length": 100,
"hidden_size": 256,
"num_layers": 2,
"epoch": 20,
"batch_size": 16,
"optimizer": "adam",
"learning_rate": 1e-3,
"use_crf": False,
"class_num": 9,
"bert_path": r"F:\人工智能NLP/NLP资料\week6 语言模型/bert-base-chinese",
"vocab_size": 20000
}
二、数据加载 loader.py
1.初始化数据加载类
def __init__(self, data_path, config):构造函数接收数据路径和配置对象。
data_path:数据文件存储路径
config:包含训练 / 数据配置的字典
self.config:保存包含训练 / 数据配置的字典
self.path:保存数据文件存储路径
self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具
self.sentences:初始化句子列表
self.schema:加载实体标签与索引的映射关系表
self.load:调用 load()
方法从 data_path
加载原始数据,进行分词、编码、填充/截断等预处理。
def __init__(self, data_path, config):
self.config = config
self.path = data_path
self.tokenizer = load_vocab(config["bert_path"])
self.sentences = []
self.schema = self.load_schema(config["schema_path"])
self.load()
2.加载数据并预处理
① 初始化数据容器 ——>
② 文件读取与分段处理 ——>
③ 逐段解析字符与标签 ——>
④ 句子编码与填充 ——>
⑤ 数据封装与返回
self.path:数据文件的存储路径(如 train.txt
),由类初始化时传入的 data_path
参数赋值。
f:文件对象,用于读取 self.path
指向的原始数据文件。
segments:是按双换行符分隔的段落列表,每个段落对应一个样本(如一个句子及其标注序列)。
segment:遍历 segments
时的单个样本段落,进一步按行分割处理为字符和标签
labels:存储当前样本的标签序列,[8]可能表示 [CLS]
标记的 ID,用于序列起始符,之后将每个字符的标签转换为ID。
char:当前行的字符(如 "中"
),属于句子中的一个基本单元。
lable:当前行的原始标签字符串(如 "B-LOC"
),尚未映射为 ID。
input_ids:将字符序列编码为模型输入所需的 ID 序列(如 BERT 分词后的 Token ID)
self.data:列表,存储预处理后的数据样本,每个样本由输入张量和标签张量组成
sentence:由字符列表拼接而成的完整句子(如 "中国科技大学"
),存入 self.sentences
供后续可视化或调试。
open():打开文件并返回文件对象,支持读/写/追加等模式。
参数名 | 类型 | 说明 |
---|---|---|
file |
字符串 | 文件路径(绝对/相对路径) |
mode |
字符串 | 打开模式(如 r -只读、w -写入、a -追加) |
encoding |
字符串 | 文件编码(如 utf-8 ,文本模式需指定) |
errors |
字符串 | 编码错误处理方式(如 ignore 、replace ) |
文件对象.read():读取文件内容,返回字符串或字节流
参数名 | 类型 | 说明 |
---|---|---|
size |
整数 | 可选,指定读取的字节数(默认读取全部内容) |
split():按分隔符分割字符串,返回子字符串列表
参数名 | 类型 | 说明 |
---|---|---|
delimiter |
字符串 | 分隔符(默认空格) |
maxsplit |
整数 | 可选,最大分割次数(默认-1表示全部) |
strip():去除字符串首尾指定字符(默认空白字符)
参数名 | 类型 | 说明 |
---|---|---|
chars |
字符串 | 可选,指定需去除的字符集合 |
join():用分隔符连接可迭代对象的元素,返回新字符串
参数名 | 类型 | 说明 |
---|---|---|
iterable |
可迭代对象 | 需连接的元素集合(如列表、元组) |
sep |
字符串 | 分隔符(默认空字符串) |
列表.append():在列表末尾添加元素
参数名 | 类型 | 说明 |
---|---|---|
obj |
任意类型 | 要添加的元素 |
def load(self):
self.data = []
with open(self.path, encoding="utf8") as f:
segments = f.read().split("\n\n")
for segment in segments:
sentence = []
labels = [8] # cls_token
for line in segment.split("\n"):
if line.strip() == "":
continue
char, label = line.split()
sentence.append(char)
labels.append(self.schema[label])
sentence = "".join(sentence)
self.sentences.append(sentence)
input_ids = self.encode_sentence(sentence)
labels = self.padding(labels, -1)
# print(self.decode(sentence, labels))
# input()
self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
return
3.加载字 / 词表
vocab_path:字 / 词表的存储路径
BertTokenizer.from_pretrained():Hugging Face Transformers 库中用于加载预训练 BERT 分词器的核心方法。它支持从 Hugging Face 模型库或本地路径加载预训练的分词器,并允许通过参数调整分词行为。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
pretrained_model_name_or_path |
str |
必填 | 预训练模型名称(如 bert-base-uncased )或本地路径。若为名称,自动从 Hugging Face 下载 |
cache_dir |
str |
None |
模型缓存目录。若指定,下载的模型文件会存储在此路径下 |
force_download |
bool |
False |
是否强制重新下载模型,即使本地已缓存 |
resume_download |
bool |
False |
是否断点续传下载任务 |
do_lower_case |
bool |
True (英文模型) |
是否将文本转为小写。中文模型需注意:若设为 False ,可能导致英文单词被识别为 [UNK] |
add_special_tokens |
bool |
True |
是否在输入文本中添加 [CLS] 和 [SEP] 等特殊标记 |
tokenize_chinese_chars |
bool |
True |
是否对中文字符进行逐字分词(如将“你好”拆分为“你”和“好”) |
strip_accents |
bool |
None |
是否去除重音符号(如将 é 转换为 e ) |
use_fast |
bool |
True |
是否启用快速分词模式(基于 Rust 实现,速度更快) |
def load_vocab(vocab_path):
return BertTokenizer.from_pretrained(vocab_path)
4.加载映射关系表
加载位于指定路径的 JSON 格式的模式文件,并将其内容解析为 Python 对象以便在数据生成过程中使用。
path:映射关系表schema的存储路径
open():打开文件并返回文件对象,用于读写文件内容。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
file_name |
str | 无 | 文件路径(需包含扩展名) |
mode |
str | 'r' |
文件打开模式: - 'r' : 只读- 'w' : 只写(覆盖原文件)- 'a' : 追加写入- 'b' : 二进制模式- 'x' : 创建新文件(若存在则报错) |
buffering |
int | None |
缓冲区大小(仅二进制模式有效) |
encoding |
str | None |
文件编码(仅文本模式有效,如 'utf-8' ) |
newline |
str | '\n' |
行结束符(仅文本模式有效) |
closefd |
bool | True |
是否在文件关闭时自动关闭文件描述符 |
dir_fd |
int | -1 |
文件描述符(高级用法,通常忽略) |
flags |
int | 0 |
Linux 系统下的额外标志位 |
mode |
str | 无 | (重复参数,实际使用中只需指定 mode ) |
json.load():从已打开的 JSON 文件对象中加载数据,并将其转换为 Python 对象(如字典、列表)。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
fp |
io.TextIO |
无 | 已打开的文件对象(需处于读取模式) |
indent |
int/str | None |
缩进空格数(美化输出,如 4 或 " " ) |
sort_keys |
bool | False |
是否对 JSON 键进行排序 |
load_hook |
callable | None |
自定义对象加载回调函数 |
object_hook |
callable | None |
自定义对象解析回调函数 |
def load_schema(self, path):
with open(path, encoding="utf8") as f:
return json.load(f)
5.封装数据
Ⅰ、初始化DataGenerator:初始化DataGenerator实例dg,传入data_path和config
Ⅱ、创建
DataLoader
对象:创建DataLoader实例dl,使用dg、batch_size和shuffle参数Ⅲ、返回
DataLoader
迭代器:返回dl
data_path:数据文件的路径(如 train.txt
),用于初始化 DataGenerator
,指向原始数据文件。
config:配置参数字典,通常包含 batch_size
、bert_path
、schema_path
等参数,用于控制数据加载逻辑。
dg:自定义数据集对象,继承 torch.utils.data.Dataset
,负责数据加载、预处理和样本生成。
dl:封装 DataGenerator
的迭代器,实现批量加载、多进程加速等功能,直接用于模型训练。
DataLoader():PyTorch 模型训练的标配工具,通过合理的参数配置(如 batch_size
、num_workers
、shuffle
),可以显著提升数据加载效率,尤其适用于大规模数据集和复杂预处理任务。其与 Dataset
类的配合使用,是构建高效训练管道的核心。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
dataset |
Dataset |
None |
必须参数,自定义数据集对象(需继承 torch.utils.data.Dataset )。 |
batch_size |
int | 1 |
每个批次的样本数量。 |
shuffle |
bool | False |
是否在每个 epoch 开始时打乱数据顺序(训练时推荐设为 True )。 |
num_workers |
int | 0 |
使用多线程加载数据的工人数量(需大于 0 时生效)。 |
pin_memory |
bool | False |
是否将数据存储在 pinned memory 中(加速 GPU 数据传输)。 |
drop_last |
bool | False |
如果数据集长度无法被 batch_size 整除,是否丢弃最后一个不完整的批次。 |
persistent_workers |
bool | False |
是否保持工作线程在 epoch 之间持续运行(减少多线程初始化开销)。 |
worker_init_fn |
callable | None |
自定义工作线程初始化函数。 |
# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
dg = DataGenerator(data_path, config)
dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
return dl
6.对于输入文本做截断 / 填充
Ⅰ、截断过长序列(超过预设最大长度)
Ⅱ、填充过短序列(用
pad_token
补齐到预设最大长度)
#补齐或截断输入的序列,使其可以在一个batch内运算
def padding(self, input_id, pad_token=0):
input_id = input_id[:self.config["max_length"]]
input_id += [pad_token] * (self.config["max_length"] - len(input_id))
return input_id
7.类内魔术方法
self.data:表示数据集对象本身存储的数据容器
index:表示数据集中某个样本的索引值,用于定位并返回特定位置的样本。
__len__():用于定义对象的“长度”,通过内置函数 len()
调用时返回该值。它通常用于容器类(如列表、字典、自定义数据结构),表示容器中元素的个数
__getitem__():允许对象通过索引或键值访问元素,支持 obj[index]
或 obj[key]
语法。它使对象表现得像序列(如列表)或映射(如字典)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
8.对于输入的文本编码
调用分词器编码(参数控制标准化)
self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具
self.tokenizer.encode():Hugging Face Transformers 库中 BertTokenizer
的核心方法,用于将原始文本转换为模型可处理的输入形式。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
text |
str 或 List[str] |
必填 | 输入文本(单句或句子对)。 |
text_pair |
str |
None |
第二段文本(用于句子对任务,如问答),与 text 拼接后生成 [CLS] text [SEP] text_pair [SEP] |
add_special_tokens |
bool |
True |
是否添加 [CLS] 和 [SEP] 标记。关闭后仅返回原始分词索引 |
max_length |
int |
512 |
最大序列长度。超长文本会被截断,不足则填充 |
padding |
str 或 bool |
False |
填充策略:True /'longest' (按批次最长填充)、'max_length' (按 max_length 填充) |
truncation |
str 或 bool |
False |
截断策略:True (按 max_length 截断)、'only_first' (仅截断第一句) |
return_tensors |
str |
None |
返回张量类型:
|
return_attention_mask |
bool |
True |
是否生成 attention_mask ,标识有效内容(1)与填充部分(0) |
def encode_sentence(self, text, padding=True):
return self.tokenizer.encode(text,
padding="max_length",
max_length=self.config["max_length"],
truncation=True)
9.对于编码后的输入文本作解码
(04+)
: 匹配以0
(B-LOCATION)开头,后接多个4
(I-LOCATION)的连续标签
(15+)
、(26+)
、(37+)
:分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。
sentence:输入的原句(添加 $
后的版本),用于根据标签索引提取实体文本。
lables:模型输出的标签序列,转换为字符串后通过正则匹配定位实体位置。
results:存储提取的实体,键为实体类型(如 "LOCATION"
),值为该类型实体的文本列表。
location:正则匹配结果,通过 span()
获取实体在 sentence
中的起止位置,用于提取具体文本片段。
join():将可迭代对象(列表、元组等)中的元素按指定分隔符连接成一个字符串。调用该方法的字符串作为分隔符。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
iterable |
可迭代对象 | 必填 | 需连接的元素集合,所有元素必须是字符串类型。若为空,返回空字符串。 |
str():将其他数据类型(整数、浮点数、布尔值等)转换为字符串类型。支持格式化输出和复杂对象的字符串表示。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
object |
任意类型 | 必填 | 需转换的对象,如整数、列表、字典等。 |
encoding |
字符串 | 可选 | 编码格式(仅对字节类型有效),如 utf-8 。 |
errors |
字符串 | 可选 | 编码错误处理策略,如 ignore 、replace 。 |
defaultdict():创建字典的子类,为不存在的键自动生成默认值。需指定 default_factory
(如 list
、int
)定义默认值类型。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
default_factory |
可调用对象或无参数函数 | None |
用于生成默认值的函数。若未指定,访问不存在的键会抛出 KeyError 。 |
**kwargs |
关键字参数 | 可选 | 其他初始化字典的键值对,如 name="Alice" 。 |
re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match
对象
参数名 | 类型 | 说明 |
---|---|---|
pattern |
str 或正则表达式对象 |
要匹配的正则表达式模式 |
string |
str |
要搜索的字符串 |
flags |
int (可选) |
正则匹配标志(如 re.IGNORECASE ) |
.span():返回正则匹配的起始和结束索引(左闭右开区间)
列表.append():向列表末尾添加单个元素,直接修改原列表
参数名 | 类型 | 说明 |
---|---|---|
element |
任意 | 要添加的元素 |
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence) + 2]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
print("location", s, e)
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
print("org", s, e)
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
print("per", s, e)
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
print("time", s, e)
results["TIME"].append(sentence[s:e])
return results
完整代码
DataLoader():PyTorch 中用于高效加载和管理数据集的核心工具
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
dataset |
Dataset |
必填 | 加载的数据集对象,需实现 __len__ 和 __getitem__ 方法 |
batch_size |
int |
1 |
每个批次包含的样本数 |
shuffle |
bool |
False |
是否在每个训练周期(epoch)开始时打乱数据顺序。若 sampler 被指定,则忽略此参数。 |
sampler |
Sampler |
None |
自定义数据采样策略(如随机采样 RandomSampler 或顺序采样 SequentialSampler ) |
batch_sampler |
Sampler |
None |
自定义批次采样策略(需与 batch_size 、shuffle 等参数互斥) |
num_workers |
int |
0 |
用于加载数据的子进程数。0 表示在主进程加载;大于 0 时启用多进程加速 |
collate_fn |
Callable |
None |
合并多个样本为批次的函数(如填充序列长度)。默认将 NumPy 数组转为 Tensor |
pin_memory |
bool |
False |
若为 True ,将数据复制到 CUDA 固定内存中,加速 GPU 数据传输 |
drop_last |
bool |
False |
若为 True ,丢弃最后一个不完整的批次(当数据集样本数无法被 batch_size 整除时) |
timeout |
float |
0 |
等待从子进程收集批次的超时时间(秒)。0 表示无限等待 |
worker_init_fn |
Callable |
None |
子进程初始化函数(如设置随机种子) |
prefetch_factor |
int |
2 |
每个子进程预加载的批次数量(需 num_workers > 0 ) |
persistent_workers |
bool |
False |
是否在训练周期结束后保留子进程(减少重复创建进程的开销) |
.shape: NumPy 数组或 PyTorch 张量的属性,用于获取数据的维度信息。
input():Python 的内置函数,用于从标准输入(如键盘)读取用户输入的字符串。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
prompt |
str |
"" |
可选提示信息,显示在输入前(如 input("请输入:") ) |
返回值 | str |
- | 返回用户输入的字符串,需手动转换为其他类型(如 int(input()) ) |
# -*- coding: utf-8 -*-
import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from transformers import BertTokenizer
"""
数据加载
"""
class DataGenerator:
def __init__(self, data_path, config):
self.config = config
self.path = data_path
self.tokenizer = load_vocab(config["bert_path"])
self.sentences = []
self.schema = self.load_schema(config["schema_path"])
self.load()
def load(self):
self.data = []
with open(self.path, encoding="utf8") as f:
segments = f.read().split("\n\n")
for segment in segments:
sentenece = []
labels = [8] # cls_token
for line in segment.split("\n"):
if line.strip() == "":
continue
char, label = line.split()
sentenece.append(char)
labels.append(self.schema[label])
sentence = "".join(sentenece)
self.sentences.append(sentence)
input_ids = self.encode_sentence(sentenece)
labels = self.padding(labels, -1)
# print(self.decode(sentence, labels))
# input()
self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
return
def encode_sentence(self, text, padding=True):
return self.tokenizer.encode(text,
padding="max_length",
max_length=self.config["max_length"],
truncation=True)
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence) + 2]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
print("location", s, e)
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
print("org", s, e)
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
print("per", s, e)
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
print("time", s, e)
results["TIME"].append(sentence[s:e])
return results
# 补齐或截断输入的序列,使其可以在一个batch内运算
def padding(self, input_id, pad_token=0):
input_id = input_id[:self.config["max_length"]]
input_id += [pad_token] * (self.config["max_length"] - len(input_id))
return input_id
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
def load_schema(self, path):
with open(path, encoding="utf8") as f:
return json.load(f)
def load_vocab(vocab_path):
return BertTokenizer.from_pretrained(vocab_path)
# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
dg = DataGenerator(data_path, config)
dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
return dl
if __name__ == "__main__":
from config import Config
dg = DataGenerator("ner_data/train", Config)
dl = DataLoader(dg, batch_size=32)
for x, y in dl:
print(x.shape, y.shape)
print(x[1], y[1])
input()
三、模型建立 model.py
1.代码运行流程
输入 x → 嵌入层 → 双向LSTM → 全连接分类层 → 分支判断:
│
├── 有 target → CRF? → 是:计算 CRF 损失(通过维特比算法计算序列概率)
│ │
│ └→ 否:计算交叉熵损失(logits 展平后与标签计算交叉熵)
│
└── 无 target → CRF? → 是:解码最优标签序列(使用CRF的decode方法)
│
└→ 否:返回原始 logits(全连接层输出的未归一化分数)
2.模型初始化
代码运行流程
输入 x → BERT预训练模型 → 分类层 → 分支判断:
│
├── 有 target → CRF? → 是:计算 CRF 损失(通过转移矩阵计算序列联合概率)
│ │
│ └→ 否:计算交叉熵损失(logits 与标签的逐位置交叉熵)
│
└── 无 target → CRF? → 是:维特比解码最优路径(考虑标签转移约束)
│
└→ 否:返回原始 logits(全连接层输出的未归一化分数)
hidden_size:定义LSTM隐藏层的维度(即每个时间步输出的特征数量
vocab_size:词表大小,即嵌入层(Embedding)可处理的词汇总数
max_length:输入序列的最大长度,用于数据预处理(如截断或填充)
class_num:分类任务的类别数量,决定线性层(nn.Linear
)的输出维度
num_layers:堆叠的LSTM层数,用于增加模型复杂度
BertModel.from_pretrained():加载预训练的 BERT 模型,支持从本地或 Hugging Face 模型库加载
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
pretrained_model_name |
字符串 | 无 | 预训练模型名称或路径(如 bert-base-chinese ) |
config |
字典/类 | 默认配置 | 自定义模型配置,覆盖默认参数(如隐藏层维度、注意力头数) |
cache_dir |
字符串 | None |
模型缓存目录 |
output_hidden_states |
布尔值 | False |
是否返回所有隐藏层输出(用于特征提取) |
nn.Linear():实现全连接层的线性变换(y = xW^T + b
)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
in_features |
整数 | 无 | 输入特征维度(如词向量维度 hidden_size ) |
out_features |
整数 | 无 | 输出特征维度(如分类类别数 class_num ) |
bias |
布尔值 | True |
是否启用偏置项 |
CRF():条件随机场层,用于序列标注任务中约束标签转移逻辑。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
num_tags |
整数 | 无 | 标签类别数(如 class_num ) |
batch_first |
布尔值 | False |
输入张量是否为 (batch_size, seq_len) 格式 |
torch.nn.CrossEntropyLoss():计算交叉熵损失,常用于分类任务
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
ignore_index |
整数 | -1 |
忽略指定索引的标签(如填充符 -1 ) |
reduction |
字符串 | mean |
损失聚合方式(可选 none 、sum 、mean ) |
def __init__(self, config):
super(TorchModel, self).__init__()
hidden_size = config["hidden_size"]
vocab_size = config["vocab_size"] + 1
max_length = config["max_length"]
class_num = config["class_num"]
num_layers = config["num_layers"]
# self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)
# self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
self.classify = nn.Linear(hidden_size * 2, class_num)
self.crf_layer = CRF(class_num, batch_first=True)
self.use_crf = config["use_crf"]
self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1) #loss采用交叉熵损失
3.前向计算
代码运行流程
输入 x → 嵌入层 → LSTM层 → 分类层 → 分支判断:
│
├── 有 target → CRF? → 是:计算 CRF 损失
│ │
│ └→ 否:计算交叉熵损失
│
└── 无 target → CRF? → 是:解码最优标签序列
│
└→ 否:返回预测 logits
x:输入序列的 Token ID 矩阵,代表一个批次的文本数据(如 [[101, 234, ...], [103, 456, ...]]
)。
target:真实标签序列(如实体标注),若不为 None
表示训练阶段,需计算损失;否则为预测阶段。
predict:分类层输出的每个位置标签的未归一化分数(logits),用于后续的 CRF 或交叉熵损失计算。
mask:标记序列中有效 Token 的位置(非填充部分),target.gt(-1)
表示标签值大于 -1
的位置有效。
gt():张量的逐元素比较函数,返回布尔型张量,标记输入张量中大于指定值的元素位置。常用于生成掩码(如忽略填充符)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
other |
Tensor/标量 | 无 | 比较的阈值或张量。若为标量,则张量中每个元素与该值比较;若为张量,需与输入张量形状相同。 |
out |
Tensor | None | 可选输出张量,用于存储结果。 |
shape():返回张量的维度信息,描述各轴的大小。
view():调整张量的形状,支持自动推断维度(通过-1
占位符)。常用于数据展平或维度转换。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
*shape |
可变参数 | 无 | 目标形状的维度序列,如view(2, 3) 或view(-1, 28) ,-1 表示自动计算。 |
#当输入真实标签,返回loss值;无真实标签,返回预测值
def forward(self, x, target=None):
x = self.embedding(x) #input shape:(batch_size, sen_len)
x, _ = self.layer(x) #input shape:(batch_size, sen_len, input_dim)
predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)
if target is not None:
if self.use_crf:
mask = target.gt(-1)
# loss 是 crf 的相反数,即 - crf(predict, target, mask)
return - self.crf_layer(predict, target, mask, reduction="mean")
else:
#(number, class_num), (number)
return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
else:
if self.use_crf:
return self.crf_layer.decode(predict)
else:
return predict
4.选择优化器
代码运行流程
输入 config → 提取参数 → 分支判断:
│
├── optimizer == "adam" → 返回 Adam 优化器实例
│
└── optimizer == "sgd" → 返回 SGD 优化器实例
config:这个参数应该是一个字典,里面存储了配置信息。
model:这是传入的模型对象,通常是一个神经网络模型。优化器需要模型的参数来更新权重
optimizer:从config中获取的字符串,决定使用哪种优化器。比如"adam"对应Adam优化器,"sgd"对应随机梯度下降。
learning_rate:学习率,是优化器的一个重要超参数,控制权重更新的步长
Adam():自适应矩估计优化器(Adaptive Moment Estimation),结合动量和 RMSProp 的优点。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
lr | float | 1e-3 | 学习率。 |
betas | tuple | (0.9, 0.999) | 动量系数(β₁, β₂)。 |
eps | float | 1e-8 | 防止除零误差。 |
weight_decay | float | 0 | 权重衰减率。 |
amsgrad | bool | False | 是否启用 AMSGrad 优化。 |
foreach | bool | False | 是否为每个参数单独计算梯度。 |
SGD():随机梯度下降优化器(Stochastic Gradient Descent)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
lr | float | 1e-3 | 学习率。 |
momentum | float | 0 | 动量系数(如 momentum=0.9 )。 |
weight_decay | float | 0 | 权重衰减率。 |
dampening | float | 0 | 动力衰减系数(用于 SGD with Momentum)。 |
nesterov | bool | False | 是否启用 Nesterov 动量。 |
foreach | bool | False | 是否为每个参数单独计算梯度。 |
parameters():返回模型所有可训练参数的迭代器,常用于参数初始化或梯度清零。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
filter | callable | None | 过滤条件函数(如 lambda p: p.requires_grad )。默认返回所有参数。 |
def choose_optimizer(config, model):
optimizer = config["optimizer"]
learning_rate = config["learning_rate"]
if optimizer == "adam":
return Adam(model.parameters(), lr=learning_rate)
elif optimizer == "sgd":
return SGD(model.parameters(), lr=learning_rate)
5.模型建立
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
import torch
from transformers import BertModel
"""
建立网络模型结构
"""
class TorchModel(nn.Module):
def __init__(self, config):
super(TorchModel, self).__init__()
hidden_size = config["hidden_size"]
vocab_size = config["vocab_size"] + 1
max_length = config["max_length"]
class_num = config["class_num"]
num_layers = config["num_layers"]
self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
# self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)
self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
self.classify = nn.Linear(hidden_size * 2, class_num)
self.crf_layer = CRF(class_num, batch_first=True)
self.use_crf = config["use_crf"]
self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1) #loss采用交叉熵损失
#当输入真实标签,返回loss值;无真实标签,返回预测值
def forward(self, x, target=None):
x = self.embedding(x) #input shape:(batch_size, sen_len)
x, _ = self.layer(x) #input shape:(batch_size, sen_len, input_dim)
predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)
if target is not None:
if self.use_crf:
mask = target.gt(-1)
# loss 是 crf 的相反数,即 - crf(predict, target, mask)
return - self.crf_layer(predict, target, mask, reduction="mean")
else:
#(number, class_num), (number)
return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
else:
if self.use_crf:
return self.crf_layer.decode(predict)
else:
return predict
def choose_optimizer(config, model):
optimizer = config["optimizer"]
learning_rate = config["learning_rate"]
if optimizer == "adam":
return Adam(model.parameters(), lr=learning_rate)
elif optimizer == "sgd":
return SGD(model.parameters(), lr=learning_rate)
if __name__ == "__main__":
from config import Config
model = TorchModel(Config)
四、模型效果测试 evaluate.py
1.代码运行流程
输入验证集 → 数据加载 → 模型预测 → 分支判断:
│
├── 启用CRF → 直接解码标签序列 → 实体提取
│
└── 禁用CRF → argmax获取预测标签 → 实体提取
→ 统计指标计算 → 分支判断:
│
├── 按实体类别统计 → 计算precision/recall/F1(LOCATION/TIME/PERSON/ORGANIZATION)
│
└── 全局统计 → 计算micro-F1 → 输出综合评估结果
2.初始化
Ⅰ、加载配置文件、模型及日志模块 ——>
Ⅱ、读取验证集数据(固定顺序,避免随机性干扰评估)——>
Ⅲ、初始化统计字典
stats_dict
,按实体类别记录正确识别数、样本实体数等
config:存储运行时配置,例如数据路径、超参数(如批次大小 batch_size
)、是否使用CRF层等。通过 config["valid_data_path"]
动态获取验证集路径。
model:待评估的模型实例,用于调用预测方法(如 model(input_id)
),需提前完成训练和加载。
logger:记录运行日志,例如输出评估指标(准确率、F1值)到文件或控制台,便于调试和监控。
valid_data:验证数据集,用于模型训练时的性能评估和超参数调优。
load_data():数据加载类中,用torch自带的DataLoader类封装数据的函数
def __init__(self, config, model, logger):
self.config = config
self.model = model
self.logger = logger
self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
3.统计模型效果
Ⅰ、输入验证与初始化
通过
assert
确保输入的三组数据长度一致(labels
,pred_results
,sentences
)。若模型未使用 CRF 层(
use_crf=False
),将预测结果通过torch.argmax
转换为标签索引序列
Ⅱ、逐样本处理
遍历每个样本的真实标签、预测标签及原始句子。
若未使用 CRF,将预测标签从 GPU Tensor 转换为 CPU List(避免内存泄漏)。
调用
decode()
方法解码标签序列,得到真实实体字典true_entities
和预测实体字典pred_entities
Ⅲ、实体统计
对每个实体类别(如
PERSON
,LOCATION
):正确识别数:遍历预测实体列表,统计与真实实体完全匹配的数量(
ent in true_entities[key]
)。样本实体数:统计真实实体列表的长度。
识别出实体数:统计预测实体列表的长度。
Ⅳ、输出统计结果
最终统计结果存储在
self.stats_dict
中,后续可通过该字典计算准确率(正确识别数 / 识别出实体数
)和召回率(正确识别数 / 样本实体数
)
labels:真实标签序列(如实体标注的整数 ID 列表),用于与预测结果对比计算评估指标
pred_results:模型预测结果,若使用 CRF,为标签序列,否则为每个位置的 logits(未归一化概率)。
sentences:原始文本句子列表(如 ["中国北京", "今天天气"]
),用于解码标签序列到具体实体。
use_crf:控制是否使用 CRF 层
pred_label:单个样本的预测标签序列,若未使用 CRF,需从 logits 中提取(argmax
)并转换为列表。
true_label:单个样本的真实标签序列(如 [0, 4, 4, 8]
),已从 GPU 张量转换为 CPU 列表。
true_entities:解码后的真实实体字典,如 {"LOCATION": ["北京"], "PERSON": []}
pred_entities:解码后的预测实体字典,用于与真实实体对比统计正确识别数。
key:字符串,实体类别名称(如 "PERSON"
),遍历四类实体以分别统计指标。
assert:Python 的 调试断言工具,主要用于在开发阶段验证程序内部的逻辑条件是否成立
assert expression [, message]
参数 | 类型 | 是否必填 | 作用 |
---|---|---|---|
expression | 布尔表达式 | 是 | 需要验证的条件。若结果为 False ,则触发断言失败;若为 True ,程序继续执行。 |
message | 字符串(可选) | 否 | 断言失败时输出的自定义错误信息,用于辅助调试。若省略,则输出默认错误提示。 |
len():返回对象的元素数量(字符串、列表、元组、字典等)
参数名 | 类型 | 说明 |
---|---|---|
object | 任意可迭代对象 | 如字符串、列表、字典等 |
torch.argmax():返回张量中最大值所在的索引
参数名 | 类型 | 说明 |
---|---|---|
input | Tensor | 输入张量 |
dim | int | 沿指定维度查找最大值 |
keepdim | bool | 是否保持输出维度一致 |
cpu():将张量从GPU移动到CPU内存
zip():将多个可迭代对象打包成元组列表
参数名 | 类型 | 说明 |
---|---|---|
iterables | 多个可迭代对象 | 如列表、元组、字符串 |
.detach():从计算图中分离张量,阻止梯度传播
.tolist():将张量或数组转换为Python列表
def write_stats(self, labels, pred_results, sentences):
assert len(labels) == len(pred_results) == len(sentences)
if not self.config["use_crf"]:
pred_results = torch.argmax(pred_results, dim=-1)
for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
if not self.config["use_crf"]:
pred_label = pred_label.cpu().detach().tolist()
true_label = true_label.cpu().detach().tolist()
true_entities = self.decode(sentence, true_label)
pred_entities = self.decode(sentence, pred_label)
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
self.stats_dict[key]["样本实体数"] += len(true_entities[key])
self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
return
4.可视化统计模型效果
精确率 (Precision):正确预测实体数 / 总预测实体数
召回率 (Recall):正确预测实体数 / 总真实实体数
F1值:精确率与召回率的调和平均
F1:F1分数:准确率与召回率的调和平均数,综合衡量模型的精确性与覆盖能力。
F1_scores:存储四个实体类别的 F1 分数,用于计算宏观平均。
precision:准确率:模型预测为某类实体的结果中,正确的比例。反映模型预测的精确度。
recall:召回率:真实存在的某类实体中,被模型正确识别的比例。反映模型对实体的覆盖能力。
key:当前处理的实体类别(如 "PERSON"
、"LOCATION"
)。
correct_pred:总正确识别数:所有类别中被正确识别的实体总数。
total_pred:总识别实体数:模型预测出的所有实体数量(含错误识别)。
true_enti:总样本实体数:验证数据中真实存在的所有实体数量。
micro_precision:微观准确率:全局视角下的准确率,所有实体类别的正确识别数与总识别数的比例。
micro_recall:微观召回率:全局视角下的召回率,所有实体类别的正确识别数与总样本实体数的比例。
micro_f1:微观F1分数:微观准确率与微观召回率的调和平均数。
列表.append():在列表末尾添加元素
参数名 | 类型 | 说明 |
---|---|---|
element | 任意 | 要添加的元素 |
logger.info():记录日志信息(需配置日志模块)
参数名 | 类型 | 说明 |
---|---|---|
format | str | 格式化字符串 |
*args | 可变参数 | 格式化参数 |
sum():计算可迭代对象的元素总和
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 如列表、元组 |
start | 数值(可选) | 初始累加值 |
列表推导式:通过简洁语法生成新列表,语法:[表达式 for item in iterable if 条件]
def show_stats(self):
F1_scores = []
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
F1 = (2 * precision * recall) / (precision + recall + 1e-5)
F1_scores.append(F1)
self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
micro_precision = correct_pred / (total_pred + 1e-5)
micro_recall = correct_pred / (true_enti + 1e-5)
micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
self.logger.info("Micro-F1 %f" % micro_f1)
self.logger.info("--------------------")
return
5.评估模型效果
模型切换为评估模式:关闭Dropout等训练层
批次处理数据:
提取原始句子
sentences
将数据迁移至GPU(若可用)
预测时禁用梯度计算(
torch.no_grad()
)优化内存统计结果:调用
write_stats
对比预测与真实标签
epoch:当前训练轮次,用于日志。
logger:记录日志的工具。
stats_dict:统计字典,记录各实体类别的指标。
valid_data:验证数据集,通常由 load_data
加载(如 config["valid_data_path"]
指定路径)
index: 循环中的批次索引
batch_data: 循环中的数据。
sentences:当前批次的原始句子
pred_results:模型预测结果
write_stats():写入统计信息
show_stats():显示统计结果
logger.info():记录日志信息(需配置日志模块)
参数名 | 类型 | 说明 |
---|---|---|
format | str | 格式化字符串 |
*args | 可变参数 | 格式化参数 |
defaultdict():创建带有默认值工厂的字典
参数名 | 类型 | 说明 |
---|---|---|
default_factory | 可调用对象 | 如int、list、自定义函数 |
model.eval():将模型设置为评估模式(关闭Dropout等训练层)
enumerate():返回索引和元素组成的枚举对象
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 如列表、字符串 |
start | int(可选) | 起始索引,默认为0 |
torch.cuda.is_available():检查当前环境是否支持CUDA(GPU加速)
cuda():将张量或模型移动到GPU
参数名 | 类型 | 说明 |
---|---|---|
device | int/str | 指定GPU设备号,如"cuda:0" |
torch.no_grad():禁用梯度计算,节省内存并加速推理
def eval(self, epoch):
self.logger.info("开始测试第%d轮模型效果:" % epoch)
self.stats_dict = {"LOCATION": defaultdict(int),
"TIME": defaultdict(int),
"PERSON": defaultdict(int),
"ORGANIZATION": defaultdict(int)}
self.model.eval()
for index, batch_data in enumerate(self.valid_data):
sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
if torch.cuda.is_available():
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
with torch.no_grad():
pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
self.write_stats(labels, pred_results, sentences)
self.show_stats()
return
6.解码
根据代码中,Schema文件映射的定义对标签序列预处理:将数值标签拼接为字符串(如
[0,4,4]
→"044"
)正则匹配实体:
04+
:B-LOCATION(0)后接多个I-LOCATION(4)
15+
:B-ORGANIZATION(1)后接I-ORGANIZATION(5)其他实体类别同理
索引对齐:根据匹配位置截取原始句子中的实体文本
Ⅰ、输入预处理
在原句首添加 $
符号,通常用于对齐标签与字符位置(例如避免索引越界)
sentence = "$" + sentence
Ⅱ、标签序列转换
将整数标签序列转换为字符串,并截取长度与 sentence
对齐
str.join():将可迭代对象中的字符串元素按指定分隔符连接成一个新字符串
参数名 | 类型 | 说明 |
---|---|---|
iterable |
可迭代对象 | 元素必须为字符串类型 |
str():将对象转换为字符串表示形式,支持自定义类的 __str__
方法
参数名 | 类型 | 说明 |
---|---|---|
object |
任意 | 要转换的对象 |
len():返回对象的长度或元素个数(适用于字符串、列表、字典等)
参数名 | 类型 | 说明 |
---|---|---|
object |
可迭代对象 | 如字符串、列表等 |
列表推导式:通过简洁语法生成新列表,支持条件过滤和多层循环
[expression for item in iterable if condition]
部分 | 类型 | 说明 |
---|---|---|
expression |
表达式 | 对 item 处理后的结果 |
item |
变量 | 迭代变量 |
iterable |
可迭代对象 | 如列表、range() 生成的序列 |
condition |
条件表达式 (可选) | 过滤不符合条件的元素 |
labels = "".join([str(x) for x in labels[:len(sentence)+1]])
Ⅲ、初始化结果容器
创建默认值为列表的字典,存储四类实体:
(LOCATION、ORGANIZATION、PERSON、TIME)的识别结果
defaultdict():创建默认值字典,当键不存在时自动生成默认值(基于工厂函数)
参数名 | 类型 | 说明 |
---|---|---|
default_factory |
可调用对象 | 如 int 、list 或自定义函数 |
results = defaultdict(list)
Ⅳ、正则表达式匹配
(04+)
: 匹配以0
(B-LOCATION)开头,后接多个4
(I-LOCATION)的连续标签
(15+)
、(26+)
、(37+)
:分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。
re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match
对象
参数名 | 类型 | 说明 |
---|---|---|
pattern |
str 或正则表达式对象 |
要匹配的正则表达式模式 |
string |
str |
要搜索的字符串 |
flags |
int (可选) |
正则匹配标志(如 re.IGNORECASE ) |
.span():返回正则匹配的起始和结束索引(左闭右开区间)
列表.append():向列表末尾添加单个元素,直接修改原列表
参数名 | 类型 | 说明 |
---|---|---|
element |
任意 | 要添加的元素 |
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
Ⅴ、完整代码
'''
Schema文件
{
"B-LOCATION": 0,
"B-ORGANIZATION": 1,
"B-PERSON": 2,
"B-TIME": 3,
"I-LOCATION": 4,
"I-ORGANIZATION": 5,
"I-PERSON": 6,
"I-TIME": 7,
"O": 8
}
'''
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence)+1]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
results["TIME"].append(sentence[s:e])
return results
7.完整代码
# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data
"""
模型效果测试
"""
class Evaluator:
def __init__(self, config, model, logger):
self.config = config
self.model = model
self.logger = logger
self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
def eval(self, epoch):
self.logger.info("开始测试第%d轮模型效果:" % epoch)
self.stats_dict = {"LOCATION": defaultdict(int),
"TIME": defaultdict(int),
"PERSON": defaultdict(int),
"ORGANIZATION": defaultdict(int)}
self.model.eval()
for index, batch_data in enumerate(self.valid_data):
sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
if torch.cuda.is_available():
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
with torch.no_grad():
pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
self.write_stats(labels, pred_results, sentences)
self.show_stats()
return
def write_stats(self, labels, pred_results, sentences):
assert len(labels) == len(pred_results) == len(sentences)
if not self.config["use_crf"]:
pred_results = torch.argmax(pred_results, dim=-1)
for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
if not self.config["use_crf"]:
pred_label = pred_label.cpu().detach().tolist()
true_label = true_label.cpu().detach().tolist()
true_entities = self.decode(sentence, true_label)
pred_entities = self.decode(sentence, pred_label)
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
self.stats_dict[key]["样本实体数"] += len(true_entities[key])
self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
return
def show_stats(self):
F1_scores = []
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
F1 = (2 * precision * recall) / (precision + recall + 1e-5)
F1_scores.append(F1)
self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
micro_precision = correct_pred / (total_pred + 1e-5)
micro_recall = correct_pred / (true_enti + 1e-5)
micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
self.logger.info("Micro-F1 %f" % micro_f1)
self.logger.info("--------------------")
return
'''
{
"B-LOCATION": 0,
"B-ORGANIZATION": 1,
"B-PERSON": 2,
"B-TIME": 3,
"I-LOCATION": 4,
"I-ORGANIZATION": 5,
"I-PERSON": 6,
"I-TIME": 7,
"O": 8
}
'''
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence)+1]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
results["TIME"].append(sentence[s:e])
return results
五、主函数文件 main.py
1.代码运行流程
配置参数 → 创建模型目录 → 加载训练数据 → 初始化模型 → 设备检测:
│
├── GPU可用 → 迁移模型至GPU
│
└── GPU不可用 → 保持CPU模式
→ 选择优化器 → 初始化评估器 → 进入训练循环:
│
├── 当前epoch → 训练模式 → 遍历数据批次:
│ │
│ ├── 清空梯度 → 数据迁移至GPU → 前向计算 → 分支判断:
│ │ │
│ │ ├── 启用CRF → 计算CRF损失 → 反向传播 → 参数更新
│ │ │
│ │ └── 禁用CRF → 计算交叉熵损失 → 反向传播 → 参数更新
│ │
│ └── 记录批次损失 → 周期中点打印日志
│
└── 计算epoch平均损失 → 验证集评估 → 保存当前模型权重
2.导入文件
# -*- coding: utf-8 -*-
import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
3.日志配置
logging.basicConfig():配置日志系统的基础参数(一次性设置,应在首次日志调用前调用)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
filename |
字符串 | 否 | None |
日志输出文件名(若指定,日志写入文件而非控制台) |
filemode |
字符串 | 否 | 'a' |
文件打开模式(如'w' 覆盖,'a' 追加) |
format |
字符串 | 否 | 基础格式 | 日志格式模板(如'%(asctime)s - %(levelname)s - %(message)s' ) |
datefmt |
字符串 | 否 | 无 | 时间格式(如'%Y-%m-%d %H:%M:%S' ) |
level |
整数 | 否 | WARNING |
日志级别(如logging.INFO 、logging.DEBUG ) |
stream |
对象 | 否 | None |
指定日志输出流(如sys.stderr ,与filename 互斥) |
logging.getLogger():获取或创建指定名称的日志记录器(Logger)。若name
为None
,返回根日志记录器
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
name |
字符串 | 否 | None |
日志记录器名称(分层结构,如'module.sub' ) |
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
4.主函数 main
Ⅰ、创建模型保存目录
os.path.isdir():检查指定路径是否为目录(文件夹)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
path |
字符串 | 是 | 无 | 要检查的路径(绝对或相对) |
os.mkdir():创建单个目录(若父目录不存在会抛出异常)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
path |
字符串 | 是 | 无 | 要创建的目录路径 |
mode |
整数 | 否 | 0o777 |
目录权限(八进制格式,某些系统可能忽略此参数) |
#创建保存模型的目录
if not os.path.isdir(config["model_path"]):
os.mkdir(config["model_path"])
Ⅱ、加载训练数据
#加载训练数据
train_data = load_data(config["train_data_path"], config)
Ⅲ、加载模型
#加载模型
model = TorchModel(config)
Ⅳ、检查GPU并迁移模型
torch.cuda.is_available():检查系统是否满足 CUDA 环境要求
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg |
str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args |
Any | 否 | 格式化参数(用于% 占位符) |
cuda():将张量或模型移动到GPU显存,加速计算
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
device |
int/str | 否 | 指定GPU设备(如0 或"cuda:0" ) |
tensor.cuda(device=0) |
non_blocking |
bool | 否 | 是否异步传输数据(默认False) | tensor.cuda(non_blocking=True) |
# 标识是否使用gpu
cuda_flag = torch.cuda.is_available()
if cuda_flag:
logger.info("gpu可以使用,迁移模型至gpu")
model = model.cuda()
Ⅴ、加载优化器
#加载优化器
optimizer = choose_optimizer(config, model)
Ⅵ、加载评估器
#加载效果测试类
evaluator = Evaluator(config, model, logger)
Ⅶ、模型训练 ⭐
① Epoch循环控制
range():Python 内置函数,用于生成一个不可变的整数序列,核心功能是为循环控制提供高效的数值迭代支持
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
start |
整数 | 0 |
序列起始值(包含)。若省略,则默认从 0 开始。例如 range(3) 等价于 range(0,3) 。 |
stop |
整数 | 必填 | 序列结束值(不包含)。例如 range(2, 5) 生成 2,3,4 |
step |
整数 | 1 |
步长(正/负): - 正步长需满足 start < stop ,否则无输出(如 range(5, 2) 无效)。- 负步长需满足 start > stop ,例如 range(5, 0, -1) 生成 5,4,3,2,1 **不能为 0 **(否则触发 ValueError ) |
for epoch in range(config["epoch"]):
epoch += 1
② 模型设置训练模式
train_loss:计算当前批次的损失值,通常结合损失函数(如交叉熵、均方误差)使用
model.train():设置模型为训练模式,启用Dropout、BatchNorm等层的训练行为
参数 | 类型 | 默认值 | 说明 | 示例 |
---|---|---|---|---|
mode |
bool | True |
是否启用训练模式(True)或评估模式(False) | model.train(True) |
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg |
str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args |
Any | 否 | 格式化参数(用于% 占位符) |
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
③ Batch数据遍历
enumerate():遍历可迭代对象时返回索引和元素,支持自定义起始索引
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
iterable |
Iterable | 是 | 可迭代对象(如列表、生成器) | enumerate(["a", "b"]) |
start |
int | 否 | 索引起始值(默认0) | enumerate(data, start=1) |
for index, batch_data in enumerate(train_data):
④ 梯度清零与设备切换
optimizer.zero_grad():清空模型参数的梯度,防止梯度累积
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
set_to_none |
bool | 否 | 是否将梯度置为None (高效但危险) |
optimizer.zero_grad(True) |
cuda():将张量或模型移动到GPU显存,加速计算
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
device |
int/str | 否 | 指定GPU设备(如0 或"cuda:0" ) |
tensor.cuda(device=0) |
non_blocking |
bool | 否 | 是否异步传输数据(默认False) | tensor.cuda(non_blocking=True) |
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
⑤ 前向传播与损失计算
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
⑥ 反向传播与参数更新
loss.backward():反向传播计算梯度,基于损失值更新模型参数的.grad
属性
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
retain_graph |
bool | 否 | 是否保留计算图(用于多次反向传播) | loss.backward(retain_graph=True) |
optimizer.step():根据梯度更新模型参数,执行优化算法(如SGD、Adam)
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
closure |
Callable | 否 | 重新计算损失的闭包函数(如LBFGS) | optimizer.step(closure) |
loss.backward()
optimizer.step()
⑦ 损失记录与日志输出
列表.append():在列表末尾添加元素,直接修改原列表
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
object |
Any | 是 | 要添加到列表末尾的元素 | train_loss.append(loss.item()) |
int():将字符串或浮点数转换为整数,支持进制转换
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
x |
str/float | 是 | 待转换的值(如字符串或浮点数) | int("10", base=2) (输出2进制10=2) |
base |
int | 否 | 进制(默认10) |
len():返回对象(如列表、字符串)的长度或元素个数
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
obj |
Sequence/Collection | 是 | 可计算长度的对象(如列表、字符串) | len([1, 2, 3]) (返回3) |
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg |
str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args |
Any | 否 | 格式化参数(用于% 占位符) |
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
⑧ Epoch评估与日志
item():从张量中提取标量值(仅当张量包含单个元素时可用)
列表.append():Python 列表(list)的内置方法,用于向列表的 末尾 添加一个元素。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
element |
任意类型 | 无 | 要添加到列表末尾的元素。可以是单个值(如 42 )、对象(如 [1, 2, 3] )等。 |
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg |
str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args |
Any | 否 | 格式化参数(用于% 占位符) |
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
⑨ 完整训练代码
#训练
for epoch in range(config["epoch"]):
epoch += 1
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
for index, batch_data in enumerate(train_data):
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
logger.info("epoch average loss: %f" % np.mean(train_loss))
evaluator.eval(epoch)
Ⅷ、模型保存
os.path.join():Python 中用于拼接路径的核心函数,其核心价值在于自动处理不同操作系统的路径分隔符,从而保证代码的跨平台兼容性
参数 | 类型 | 必填 | 说明 |
---|---|---|---|
path1 |
字符串 | 是 | 初始路径组件 |
*paths |
可变参数 | 否 | 后续路径组件(可传多个) |
torch.save(): PyTorch 中用于序列化保存模型、张量或字典等对象的核心函数,支持将数据持久化存储为 .pth
或 .pt
文件,便于后续加载和复用
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
obj |
任意 PyTorch 对象 | 必填 | 待保存的对象,如模型、张量或字典。 |
f |
str 或文件对象 |
必填 | 保存路径(如 'model.pth' )或已打开的文件对象(需二进制写入模式 'wb' ) |
pickle_protocol |
int |
2 |
指定 pickle 协议版本(通常无需修改,高版本可能提升效率但需兼容性验证) |
_use_new_zipfile_serialization |
bool |
True |
启用新版序列化格式(压缩率更高,推荐保持默认) |
model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
# torch.save(model.state_dict(), model_path)
return model, train_data
5.调用模型预测
# -*- coding: utf-8 -*-
import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
"""
模型训练主程序
"""
def main(config):
#创建保存模型的目录
if not os.path.isdir(config["model_path"]):
os.mkdir(config["model_path"])
#加载训练数据
train_data = load_data(config["train_data_path"], config)
#加载模型
model = TorchModel(config)
# 标识是否使用gpu
cuda_flag = torch.cuda.is_available()
if cuda_flag:
logger.info("gpu可以使用,迁移模型至gpu")
model = model.cuda()
#加载优化器
optimizer = choose_optimizer(config, model)
#加载效果测试类
evaluator = Evaluator(config, model, logger)
#训练
for epoch in range(config["epoch"]):
epoch += 1
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
for index, batch_data in enumerate(train_data):
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
logger.info("epoch average loss: %f" % np.mean(train_loss))
evaluator.eval(epoch)
# 保存模型
model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
torch.save(model.state_dict(), model_path)
return model, train_data
if __name__ == "__main__":
model, train_data = main(Config)