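With the previous section, the data the model needs is fully prepared, so in this section we start building the model itself.

The overall model is a BiLSTM+CRF architecture. Because the CRF involves decoding and its own way of computing the loss, this section covers only the BiLSTM part, and the next subsection covers the CRF.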

Model structure

Reference: https://createmomo.github.io/
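
The BiLSTM half of this structure can be traced as a sequence of tensor shapes. The following is only a minimal sketch: the layer sizes are copied from this section's config so the snippet runs on its own, while the batch size, sequence length, and the names `token_ids` and `emissions` are made up for illustration.

```python
import torch
import torch.nn as nn

# Sizes mirror this section's config.py; BATCH and SEQ_LEN are illustrative.
VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, TARGET_SIZE = 3000, 100, 256, 31
BATCH, SEQ_LEN = 4, 10

embed = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True, bidirectional=True)
linear = nn.Linear(2 * HIDDEN_SIZE, TARGET_SIZE)

token_ids = torch.randint(0, VOCAB_SIZE, (BATCH, SEQ_LEN))
embedded = embed(token_ids)    # (BATCH, SEQ_LEN, EMBEDDING_DIM)
hidden, _ = lstm(embedded)     # (BATCH, SEQ_LEN, 2 * HIDDEN_SIZE)
emissions = linear(hidden)     # (BATCH, SEQ_LEN, TARGET_SIZE)

print(embedded.shape, hidden.shape, emissions.shape)
```

The last tensor holds one score per label for every token. In a BiLSTM+CRF model, these per-token scores are what the CRF layer typically takes as its input.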

Code example

1. Define the model parameters

# config.py
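VOCAB_SIZE = 3000      # vocabulary size (rows in the embedding table)
EMBEDDING_DIM = 100    # dimension of each token embedding
HIDDEN_SIZE = 256      # hidden units per LSTM direction
TARGET_SIZE = 31       # number of output labels
LR = 1e-3              # learning rate
EPOCH = 100            # number of training epochs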

MODEL_DIR = './output/model/'
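
To get a feel for what these settings imply, the number of trainable parameters can be estimated by hand. This is a rough sketch, assuming a single-layer bidirectional LSTM followed by one linear layer (as built later in this section) and PyTorch's convention of two bias vectors per gate. The helper `lstm_direction_params` is a hypothetical name introduced for illustration.

```python
# Values copied from config.py above.
VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_SIZE, TARGET_SIZE = 3000, 100, 256, 31


def lstm_direction_params(input_size, hidden_size):
    """Parameters of one LSTM direction: 4 gates, each with input weights,
    recurrent weights, and two bias vectors (PyTorch's nn.LSTM layout)."""
    return 4 * (input_size * hidden_size + hidden_size * hidden_size + 2 * hidden_size)


embedding_params = VOCAB_SIZE * EMBEDDING_DIM
bilstm_params = 2 * lstm_direction_params(EMBEDDING_DIM, HIDDEN_SIZE)
linear_params = 2 * HIDDEN_SIZE * TARGET_SIZE + TARGET_SIZE  # weights + bias
total_params = embedding_params + bilstm_params + linear_params

print(embedding_params, bilstm_params, linear_params, total_params)
```

With these values, the BiLSTM accounts for most of the roughly one million parameters, so `HIDDEN_SIZE` is the setting that most affects model size.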

2. Import modules

# model.py
import torch.nn as nn
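from config import *  # WORD_PAD_ID, used below, must also be defined in config.py
from torchcrf import CRF  # not used until the CRF part in the next subsection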
import torch

3. Define the model

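class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Token embedding; the row at WORD_PAD_ID is the padding vector
        self.embed = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM, padding_idx=WORD_PAD_ID)
        self.lstm = nn.LSTM(
            EMBEDDING_DIM,
            HIDDEN_SIZE,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward hidden states are concatenated, hence 2 * HIDDEN_SIZE
        self.linear = nn.Linear(2 * HIDDEN_SIZE, TARGET_SIZE)

    def _get_lstm_feature(self, input):
        out = self.embed(input)    # (batch, seq_len, EMBEDDING_DIM)
        out, _ = self.lstm(out)    # (batch, seq_len, 2 * HIDDEN_SIZE)
        return self.linear(out)    # (batch, seq_len, TARGET_SIZE)

    def forward(self, input, mask):
        # mask is not used in this BiLSTM-only version
        out = self._get_lstm_feature(input)
        return out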

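if __name__ == '__main__':
    model = Model()
    # Random token ids: a batch of 100 sequences, each 50 tokens long
    input = torch.randint(1, VOCAB_SIZE, (100, 50))
    # forward() requires a mask argument; here every position counts as a real token
    mask = torch.ones_like(input, dtype=torch.bool)
    print(model)
    print(model(input, mask).shape)  # torch.Size([100, 50, 31])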
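
The `forward` method accepts a `mask` that the BiLSTM part does not use yet. One common way to build such a mask is to mark the non-padding positions. The sketch below assumes a padding id of 0 (the tutorial's actual value is `WORD_PAD_ID` from config.py, which is not shown in this section) and uses made-up token ids. It also shows that `nn.Embedding` with `padding_idx` initializes the padding row to zeros.

```python
import torch
import torch.nn as nn

PAD_ID = 0  # assumption for illustration; the tutorial uses WORD_PAD_ID

# Two hypothetical sentences, padded to length 5 with PAD_ID.
token_ids = torch.tensor([
    [12, 7, 45, 0, 0],
    [3, 9, 0, 0, 0],
])

mask = token_ids.ne(PAD_ID)   # True where there is a real token
lengths = mask.sum(dim=1)     # number of real tokens per sentence

# With padding_idx set, the padding row starts as an all-zero vector.
embed = nn.Embedding(50, 4, padding_idx=PAD_ID)
pad_vector = embed(torch.tensor([PAD_ID]))

print(mask)
print(lengths)
print(pad_vector)
```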

This article is an original work by 陈华. Reposting is welcome, but please credit the source: http://www.chenhuax.com/read/394