经过前面的数据预处理操作,我们已经得到了训练样本、测试样本文件,以及词表和标签表的缓存文件,接下来就是定义Dataset类,来加载数据。另外在加载数据时,需要将文本切分为等长的句子。

代码示例

1、新建文件

# utils.py
import torch
from torch.utils import data
from config import *
import pandas as pd

2、加载词表和标签表

def get_vocab():
    df = pd.read_csv(VOCAB_PATH, names=['word', 'id'])
    return list(df['word']), dict(df.values)

def get_label():
    df = pd.read_csv(LABEL_PATH, names=['label', 'id'])
    return list(df['label']), dict(df.values)

3、Dataset

type 参数,这个类是训练和测试公用的,所以定义一个参数来区分加载哪个文件。

base_len 参数,用来定义句子的参考长度,特殊情况再稍做处理。

class Dataset(data.Dataset):
    def __init__(self, type='train', base_len=50):
        super().__init__()
        self.base_len = base_len
        sample_path = TRAIN_SAMPLE_PATH if type == 'train' else TEST_SAMPLE_PATH
        self.df = pd.read_csv(sample_path, names=['word', 'label'])
        _, self.word2id = get_vocab()
        _, self.label2id = get_label()

if __name__ == '__main__':
    dataset = Dataset()

切分文本

目前我们已经定义好了Dataset类,并将文本加载到DataFrame中,下一步,就是要将长文本拆分成句子。我们这里采用等长切分,每隔50个字切一刀。

但有一种情况需要处理,就是切点上是非O标签,则需要将切点往后移动,直达O标签为止。

1、计算分割点

def get_points(self):
    self.points = [0]
    i = 0
    while True:
        if i + self.base_len >= len(self.df):
            self.points.append(len(self.df))
            break
        if self.df.loc[i + self.base_len, 'label'] == 'O':
            i += self.base_len
            self.points.append(i)
        else:
            i += 1

2、文本数字化

def __len__(self):
    return len(self.points) - 1

def __getitem__(self, index):
    df = self.df[self.points[index]:self.points[index + 1]]
    word_unk_id = self.word2id[WORD_UNK]
    label_o_id = self.label2id['O']
    input = [self.word2id.get(w, word_unk_id) for w in df['word']]
    target = [self.label2id.get(l, label_o_id) for l in df['label']]
    return input, target

本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/392