上节课,我们做了一个简单的数据预处理,通过观察直方图,定义好了文本的长度参数。同时,如果对 Bert 不熟悉的同学,还需要看一下前面补充的内容。
现在,假设大家已经看了、并且掌握了前面 Huggingface 的内容,我们接着往下讲自定义 Dataset 和 Bert 分词的内容。
代码示例
1、新建文件
# utils.py from torch.utils import data from config import * import torch from transformers import BertTokenizer from transformers import logging logging.set_verbosity_error()
2、自定义Dataset类
class Dataset(data.Dataset):
def __init__(self, type='train'):
super().__init__()
if type == 'train':
sample_path = TRAIN_SAMPLE_PATH
elif type == 'dev':
sample_path = DEV_SAMPLE_PATH
elif type == 'test':
sample_path = TEST_SAMPLE_PATH
self.lines = open(sample_path).readlines()
self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
text, label = self.lines[index].split('\t')
tokened = self.tokenizer(text)
input_ids = tokened['input_ids']
mask = tokened['attention_mask']
if len(input_ids) < TEXT_LEN:
pad_len = (TEXT_LEN - len(input_ids))
input_ids += [BERT_PAD_ID] * pad_len
mask += [0] * pad_len
target = int(label)
return torch.tensor(input_ids[:TEXT_LEN]), torch.tensor(mask[:TEXT_LEN]), torch.tensor(target)
3、调用测试
if __name__ == '__main__':
dataset = Dataset()
loader = data.DataLoader(dataset, batch_size=2)
print(iter(loader).next())
本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/439