In deep learning, RNNs can be somewhat hard to grasp. This article uses the simplest possible setup, an LSTM model that recognizes MNIST digits, to help readers understand the model's input and output parameters. Because a basic (vanilla) RNN performs poorly on this task, we use the improved LSTM model instead.
Code example
1. Load the dataset
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Download the MNIST training set and convert each image to a tensor in [0, 1]
mnist = MNIST('./mnist/', train=True, transform=transforms.ToTensor(), download=True)
loader = DataLoader(mnist, batch_size=100, shuffle=True)  # shuffle training batches each epoch

# for train_x, train_y in loader:
#     print(train_x.shape)  # torch.Size([100, 1, 28, 28])
#     print(train_y.shape)  # torch.Size([100])
#     exit()
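Each MNIST image is 28x28. To feed it to the LSTM, every image row is treated as one timestep, so a batch of shape (100, 1, 28, 28) becomes a sequence tensor of shape (100, 28, 28), i.e. (batch, seq_len=28, input_size=28). A minimal sketch of that reshape (the fabricated batch below is illustrative only, not part of the original code):

import torch

# Illustrative only: fabricate a batch with the same shape the DataLoader yields
batch = torch.rand(100, 1, 28, 28)   # (batch, channel, height, width)
seq = batch.reshape(-1, 28, 28)      # (batch, seq_len=28, input_size=28)
print(seq.shape)                     # torch.Size([100, 28, 28])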
2. Define the model
import torch.nn as nn

class Module(nn.Module):
    def __init__(self):
        super().__init__()
        # input_size=28 (one image row per timestep), hidden_size=32
        self.rnn = nn.LSTM(28, 32, batch_first=True)
        self.out = nn.Linear(32, 10)

    def forward(self, x):
        # x: (batch, seq_len=28, input_size=28)
        x, (h_n, c_n) = self.rnn(x)   # x: (batch, seq_len, hidden_size=32)
        x = self.out(x[:, -1, :])     # classify using the last timestep's output
        return x

module = Module()
# print(module)
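As a quick sanity check (a sketch, not part of the original post), you can pass a dummy batch through the model and confirm the output shape is (batch, num_classes):

import torch

dummy = torch.rand(100, 28, 28)   # (batch, seq_len=28, input_size=28)
logits = module(dummy)
print(logits.shape)               # torch.Size([100, 10]): one score per digit class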
3. Train the model
import torch

loss_fn = nn.CrossEntropyLoss()   # expects raw logits plus integer class labels
optimizer = torch.optim.Adam(module.parameters(), lr=0.05)

for epoch in range(100):
    for i, (x, y) in enumerate(loader):
        # (batch, channel, height, width) -> (batch, timestep, input)
        x = x.reshape(-1, x.shape[2], x.shape[3])
        y_seq = module(x)
        loss = loss_fn(y_seq, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Print progress every 50 batches
        if i % 50 == 0:
            y_hat = torch.argmax(y_seq, dim=1)
            accuracy = (y_hat == y).sum() / len(y_hat)
            print('epoch', epoch, 'loss:', loss.item(), 'accuracy:', accuracy.item())
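After training, you may want to save the learned weights so the test step does not require retraining. A minimal sketch (the filename is an arbitrary choice, not from the original post):

# Save only the learned parameters (the state dict), the usual PyTorch convention
torch.save(module.state_dict(), 'lstm_mnist.pt')

# Later, restore them into a freshly constructed model before evaluating
restored = Module()
restored.load_state_dict(torch.load('lstm_mnist.pt'))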
4. Test the model
test_mnist = MNIST('./mnist/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_mnist, batch_size=100, shuffle=False)

module.eval()                      # switch to evaluation mode
with torch.no_grad():              # no gradients needed for testing
    for i, (x, y) in enumerate(test_loader):
        # (batch, channel, height, width) -> (batch, timestep, input)
        x = x.reshape(-1, x.shape[2], x.shape[3])
        y_seq = module(x)
        loss = loss_fn(y_seq, y)
        # Print progress every 50 batches
        if i % 50 == 0:
            y_hat = torch.argmax(y_seq, dim=1)
            accuracy = (y_hat == y).sum() / len(y_hat)
            print('loss:', loss.item(), 'accuracy:', accuracy.item())
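The accuracy printed above only reflects individual batches of 100 images. A sketch that accumulates accuracy over the whole test set (not part of the original code) could look like this:

correct, total = 0, 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.reshape(-1, x.shape[2], x.shape[3])
        y_hat = torch.argmax(module(x), dim=1)
        correct += (y_hat == y).sum().item()
        total += len(y)
print('overall test accuracy:', correct / total)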
The purpose of this project is to understand the input and output parameters of an RNN; the intermediate state will be covered in a follow-up project.
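As a small preview of that state (a sketch, not from the original article), nn.LSTM returns both the per-timestep output and a final (h_n, c_n) pair; their shapes show what the intermediate state actually holds:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=28, hidden_size=32, batch_first=True)
x = torch.rand(100, 28, 28)      # (batch, seq_len, input_size)
output, (h_n, c_n) = lstm(x)
print(output.shape)              # torch.Size([100, 28, 32]): hidden state at every timestep
print(h_n.shape)                 # torch.Size([1, 100, 32]): final hidden state (num_layers, batch, hidden)
print(c_n.shape)                 # torch.Size([1, 100, 32]): final cell state (num_layers, batch, hidden)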
This article is an original work by 陈华 (Chen Hua). You are welcome to repost it, but please credit the source: http://www.chenhuax.com/read/248