前面几篇文章,已经基本完全的介绍了模型定义的各种细节,包括模型定义、损失反向传播、权重参数更新等。但我们使用 Sequential 快速搭建的网络模型,只能处理简单的业务,如果碰到复杂的业务场景,继承 nn.Module 自定义模型处理类,将会是更好的选择。
代码示例
import torch
import torch.nn as nn
# 用类定义模型
class Net(nn.Module):
def __init__(self, D_in, H, D_out):
super().__init__()
self.linear1 = nn.Linear(D_in, H)
self.relu = nn.ReLU(H)
self.linear2 = nn.Linear(H, D_out)
# 前向传播
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = Net(D_in, H, D_out)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss(reduction='sum')
for i in range(500):
y_hat = model(x)
loss = loss_fn(y_hat, y)
print(i, loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/311