经过前面的数据预处理操作,我们已经得到了训练样本、测试样本文件,以及词表和标签表的缓存文件,接下来就是定义Dataset类,来加载数据。另外在加载数据时,需要将文本切分为等长的句子。
代码示例
1、新建文件
# utils.py import torch from torch.utils import data from config import * import pandas as pd
2、加载词表和标签表
def get_vocab():
df = pd.read_csv(VOCAB_PATH, names=['word', 'id'])
return list(df['word']), dict(df.values)
def get_label():
df = pd.read_csv(LABEL_PATH, names=['label', 'id'])
return list(df['label']), dict(df.values)
3、Dataset
type 参数,这个类是训练和测试公用的,所以定义一个参数来区分加载哪个文件。
base_len 参数,用来定义句子的参考长度,特殊情况再稍做处理。
class Dataset(data.Dataset):
def __init__(self, type='train', base_len=50):
super().__init__()
self.base_len = base_len
sample_path = TRAIN_SAMPLE_PATH if type == 'train' else TEST_SAMPLE_PATH
self.df = pd.read_csv(sample_path, names=['word', 'label'])
_, self.word2id = get_vocab()
_, self.label2id = get_label()
if __name__ == '__main__':
dataset = Dataset()
切分文本
目前我们已经定义好了Dataset类,并将文本加载到DataFrame中,下一步,就是要将长文本拆分成句子。我们这里采用等长切分,每隔50个字切一刀。
但有一种情况需要处理,就是切点上是非O标签,则需要将切点往后移动,直达O标签为止。
1、计算分割点
def get_points(self):
self.points = [0]
i = 0
while True:
if i + self.base_len >= len(self.df):
self.points.append(len(self.df))
break
if self.df.loc[i + self.base_len, 'label'] == 'O':
i += self.base_len
self.points.append(i)
else:
i += 1
2、文本数字化
def __len__(self):
return len(self.points) - 1
def __getitem__(self, index):
df = self.df[self.points[index]:self.points[index + 1]]
word_unk_id = self.word2id[WORD_UNK]
label_o_id = self.label2id['O']
input = [self.word2id.get(w, word_unk_id) for w in df['word']]
target = [self.label2id.get(l, label_o_id) for l in df['label']]
return input, target
本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/392