上节课,给大家介绍了 TextCNN 的模型结构,这节课就正式进入代码部分。本节课有两个任务,一是导入数据集,二是要统计待分类的文本长度,因为 TextCNN 在卷积之后,要做批量最大池化操作,所以要求文本长度一致,不够的填充PAD,太长的要进行截取。
代码示例
1、添加配置项
# config.py TRAIN_SAMPLE_PATH = './data/input/train.txt' DEV_SAMPLE_PATH = './data/input/dev.txt' TEST_SAMPLE_PATH = './data/input/test.txt' LABEL_PATH = './data/input/class.txt' BERT_PAD_ID = 0 TEXT_LEN = 35
2、统计句子长度
from config import *
import matplotlib.pyplot as plt
def count_text_len():
text_len = []
with open(TRAIN_SAMPLE_PATH) as f:
for line in f.readlines():
text, _ = line.split('\t')
text_len.append(len(text))
plt.hist(text_len)
plt.show()
print(max(text_len))
if __name__ == '__main__':
count_text_len()
做完简单的数据预处理之后,下一步就要定义 Dataset 类,加载数据了。但是数据加载,涉及到 Bert 分词,可能有的同学对 Bert 的使用还不熟悉,所以下面两节课,我们对 Bert 的基本使用做一个简单介绍,已经会用的同学可以直接跳过。
本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/438