经过上一步的处理,我们将长文本切分成了单句,但每个句子长度又不完全一样。在批处理时,要求每个batch的句子长度必须一致,所以我们需要填充 PAD,来保证句子每个batch的句子长度一样。

另外,在后续模型 CRF 阶段计算损失时,可以通过MASK,将填充的 PAD 数值忽略掉,以消除填充值 PAD 的影响

综上,我们需要在 DataLoader 批量加载数据阶段,填充 PAD 来保证批数据长度一致,且需要记录 MASK。

代码示例

1、添加配置项

WORD_PAD_ID = 0
WORD_UNK_ID = 1
LABEL_O_ID = 0

2、DataLader()

if __name__ == '__main__':
    dataset = Dataset()
    loader = data.DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
    print(iter(loader).next())

3、数据校对整理

先按句子长度从大到小排序,获取最大长度,其他句子填充到跟他一样长。

def collate_fn(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    max_len = len(batch[0][0])
    input = []
    target = []
    mask = []
    for item in batch:
        pad_len = max_len - len(item[0])
        input.append(item[0] + [WORD_PAD_ID] * pad_len)
        target.append(item[1] + [LABEL_O_ID] * pad_len)
        mask.append([1] * len(item[0]) + [0] * pad_len)
    return torch.tensor(input), torch.tensor(target), torch.tensor(mask).bool()

本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/393