上节课当中,给大家介绍了这个项目需要用到的数据集,并且做了简单的数据预处理,缓存好了关系分类文件。接下来,我们可以定义Dataset类,来加载数据了。
但这个模型的输入参数和目标值比较复杂,我们拆分成三节课来处理。这节课,先完成文件加载和分词这两块内容。
代码示例
1、添加配置项
# config.py TRAIN_JSON_PATH = './data/input/duie/duie_train.json' TEST_JSON_PATH = './data/input/duie/duie_test.json' DEV_JSON_PATH = './data/input/duie/duie_dev.json' BERT_MODEL_NAME = 'bert-base-chinese'
2、新建文件
# utils.py import torch.utils.data as data import pandas as pd import random from config import * import json from transformers import BertTokenizerFast3、加载关系表
def get_rel():
df = pd.read_csv(REL_PATH, names=['rel', 'id'])
return df['rel'].tolist(), dict(df.values)
4、Dataset初始化
class Dataset(data.Dataset):
def __init__(self, type='train'):
super().__init__()
_, self.rel2id = get_rel()
# 加载文件
if type == 'train':
file_path = TRAIN_JSON_PATH
elif type == 'test':
file_path = TEST_JSON_PATH
elif type == 'dev':
file_path = DEV_JSON_PATH
with open(file_path) as f:
self.lines = f.readlines()
# 加载bert
self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
line = self.lines[index]
info = json.loads(line)
tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
info['input_ids'] = tokenized['input_ids']
info['offset_mapping'] = tokenized['offset_mapping']
print(info)
exit()
5、尝试加载数据集
if __name__ == '__main__':
dataset = Dataset()
loader = data.DataLoader(dataset)
print(iter(loader).next())
本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/423