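By the end of the previous section we had prepared all the data the model needs, so in this section we start building the model itself.
The overall model is a BiLSTM+CRF architecture. The CRF involves decoding and a special way of computing the loss, so this section covers only the BiLSTM part, and the CRF part is introduced in the next section.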
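The sketch below makes the "bidirectional" part concrete. It runs a small bidirectional `nn.LSTM` and checks how PyTorch lays out its output. The toy sizes are chosen only for illustration and are not the values from config.py. At each time step the output concatenates the forward and backward hidden states, which is why the linear layer that follows takes `2 * HIDDEN_SIZE` inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

H = 8  # toy hidden size for this demo only
lstm = nn.LSTM(input_size=4, hidden_size=H, batch_first=True, bidirectional=True)

x = torch.randn(2, 5, 4)  # (batch, seq_len, input_size)
out, (h_n, c_n) = lstm(x)

# Output per time step: [forward state ; backward state] -> 2 * H features
print(out.shape)  # torch.Size([2, 5, 16])
# Final states are not batch-first: (num_layers * num_directions, batch, H)
print(h_n.shape)  # torch.Size([2, 2, 8])

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step.
print(torch.allclose(out[:, -1, :H], h_n[0]))  # True
print(torch.allclose(out[:, 0, H:], h_n[1]))   # True
```

So the first `H` features of each position summarize the left context and the last `H` features summarize the right context.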
Model architecture
Reference: https://createmomo.github.io/
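Here is a minimal sketch of how tensor shapes change as data flows through the BiLSTM part. It uses the hyperparameters from config.py in this section, and the batch of random word ids is made up for illustration. Each token ends up with `TARGET_SIZE` label scores, the per-token emission scores that a CRF layer builds on.

```python
import torch
import torch.nn as nn

# Hyperparameters copied from config.py in this section
VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, TARGET_SIZE = 3000, 100, 256, 31

embed = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True, bidirectional=True)
linear = nn.Linear(2 * HIDDEN_SIZE, TARGET_SIZE)

tokens = torch.randint(1, VOCAB_SIZE, (4, 20))  # (batch=4, seq_len=20) word ids

emb = embed(tokens)         # word id -> dense vector
print(emb.shape)            # torch.Size([4, 20, 100])
feats, _ = lstm(emb)        # forward + backward states concatenated
print(feats.shape)          # torch.Size([4, 20, 512])
scores = linear(feats)      # one score per label for every token
print(scores.shape)         # torch.Size([4, 20, 31])
```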
Code example
1. Define the model parameters
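```python
# config.py
VOCAB_SIZE = 3000              # vocabulary size (number of word ids)
EMBEDDING_DIM = 100            # word embedding dimension
HIDDEN_SIZE = 256              # LSTM hidden size per direction
TARGET_SIZE = 31               # number of labels (tags)
LR = 1e-3                      # learning rate
EPOCH = 100                    # number of training epochs
MODEL_DIR = './output/model/'  # directory for saved models

# Assumed value: model.py needs the padding id from the vocabulary built
# in the previous section; keep whatever id that section actually uses.
WORD_PAD_ID = 0
```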
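As a quick arithmetic check on these settings, the sketch below counts the trainable parameters of the three layers that model.py builds. The per-direction LSTM formula (four gates, each with input weights, recurrent weights and two bias vectors) follows PyTorch's `nn.LSTM` parameter layout. The layers are rebuilt standalone here and `padding_idx` is omitted, since it does not change the count.

```python
import torch.nn as nn

VOCAB_SIZE = 3000
EMBEDDING_DIM = 100
HIDDEN_SIZE = 256
TARGET_SIZE = 31


def count(module):
    """Number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


embed = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True, bidirectional=True)
linear = nn.Linear(2 * HIDDEN_SIZE, TARGET_SIZE)

# Embedding: one EMBEDDING_DIM vector per word id -> 3000 * 100
print(count(embed))   # 300000

# LSTM per direction: 4 gates * (input weights + recurrent weights) + 2 bias vectors
per_dir = 4 * HIDDEN_SIZE * (EMBEDDING_DIM + HIDDEN_SIZE) + 2 * 4 * HIDDEN_SIZE
print(2 * per_dir, count(lstm))  # 733184 733184

# Linear: 2 * HIDDEN_SIZE inputs -> TARGET_SIZE outputs, plus bias
print(count(linear))  # 15903

total = count(embed) + count(lstm) + count(linear)
print(total)          # 1049087
```

The bidirectional LSTM holds about 70% of the parameters, and the embedding table holds most of the rest.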
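2. Import modules
```python
# model.py
import torch
import torch.nn as nn
from torchcrf import CRF  # from the pytorch-crf package; needed for the CRF part in the next section
from config import *
```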
3. Define the model
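```python
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Word embedding; the vector at WORD_PAD_ID starts at zero and is never updated
        self.embed = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM, WORD_PAD_ID)
        # Bidirectional LSTM: each direction outputs HIDDEN_SIZE features per token
        self.lstm = nn.LSTM(
            EMBEDDING_DIM,
            HIDDEN_SIZE,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, hence 2 * HIDDEN_SIZE
        self.linear = nn.Linear(2 * HIDDEN_SIZE, TARGET_SIZE)

    def _get_lstm_feature(self, input):
        # (batch, seq_len) -> (batch, seq_len, EMBEDDING_DIM)
        out = self.embed(input)
        # -> (batch, seq_len, 2 * HIDDEN_SIZE)
        out, _ = self.lstm(out)
        # -> (batch, seq_len, TARGET_SIZE): one score per label for every token
        return self.linear(out)

    def forward(self, input, mask):
        # mask is not used by the BiLSTM part; it is reserved for the CRF part
        out = self._get_lstm_feature(input)
        return out


if __name__ == '__main__':
    model = Model()
    input = torch.randint(1, 3000, (100, 50))
    # The random batch contains no padding, so every position is valid
    mask = torch.ones_like(input).bool()
    print(model)
    print(model(input, mask).shape)  # torch.Size([100, 50, 31])
```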
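`forward` already takes a `mask`, even though the BiLSTM part ignores it. The sketch below shows how such a mask can be derived directly from the input ids. It assumes batches are right-padded with `WORD_PAD_ID = 0`, and the token ids are made up for illustration. Real tokens are marked `True` and padding `False`.

```python
import torch

WORD_PAD_ID = 0  # assumed padding id, matching the value assumed in config.py

# Two sentences of different lengths, right-padded to length 6 (ids are made up)
batch = torch.tensor([
    [15, 27, 8, 91, 4, 3],
    [42, 7, 66, WORD_PAD_ID, WORD_PAD_ID, WORD_PAD_ID],
])

# True for real tokens, False for padding positions
mask = batch != WORD_PAD_ID
print(mask)
print(mask.sum(dim=1))  # tensor([6, 3]): the real length of each sentence
```

Comparing against `WORD_PAD_ID` keeps the mask consistent with whatever padding the data loader applied, instead of tracking sentence lengths separately.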
This article is an original work by 陈华 (Chen Hua). Reposting is welcome, but please credit the source: http://www.chenhuax.com/read/394