import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision                      # datasets and transform utilities
import matplotlib.pyplot as plt
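The code below refers to EPOCH, BATCH_SIZE, LR and DOWNLOAD_MNIST, which are never defined in this excerpt. A minimal sketch of these hyperparameters (the specific values here are assumptions, not prescribed by the excerpt) would be:

# Hyperparameters (illustrative values, adjust as needed)
EPOCH = 1                 # iterate over the full training set this many times
BATCH_SIZE = 50           # mini-batch size for the DataLoader
LR = 0.001                # learning rate for the Adam optimizer
DOWNLOAD_MNIST = True     # set to False once the dataset has already been downloaded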
# MNIST handwritten digits
train_data = torchvision.datasets.MNIST(
    root='./mnist/',                                # where to save / load the data
    train=True,                                     # this is the training split
    transform=torchvision.transforms.ToTensor(),    # convert PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor (C x H x W), normalized to [0.0, 1.0]
    download=DOWNLOAD_MNIST,                        # download if not already present
)
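The optimizer and the training loop below use `cnn`, a model instance that this excerpt never defines. A minimal sketch of a two-convolutional-layer network for 28x28 MNIST images (the exact architecture is an assumption) could look like this:

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(          # input shape (1, 28, 28)
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),   # -> (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),     # -> (16, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),  # -> (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),     # -> (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10) # 10 digit classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)            # flatten to (batch_size, 32 * 7 * 7)
        return self.out(x)

cnn = CNN()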
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()                       # the target labels are not one-hot encoded
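The training loop iterates over `train_loader`, which is also undefined in the excerpt. A plausible definition that shuffles `train_data` and serves it in mini-batches (BATCH_SIZE is the assumed hyperparameter from above) would be:

# Wrap the dataset in a DataLoader so the loop below receives shuffled mini-batches
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)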
# training and testing
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):   # each iteration yields one normalized batch
        b_x = x   # batch of images
        b_y = y   # batch of labels
        output = cnn(b_x)               # cnn output (logits)
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients
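The comment above mentions testing as well as training. A minimal sketch of checking accuracy on held-out test data after training (the test tensors and the 2000-sample cutoff are assumptions, not part of the excerpt) might be:

# Held-out test set (first 2000 samples), shaped (N, 1, 28, 28) and scaled to [0, 1]
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
test_x = test_data.data[:2000].unsqueeze(1).float() / 255.0
test_y = test_data.targets[:2000]

with torch.no_grad():
    test_output = cnn(test_x)                           # logits for the test images
    pred_y = torch.max(test_output, 1)[1]               # predicted class = argmax over logits
    accuracy = (pred_y == test_y).float().mean().item()
print('test accuracy: %.2f' % accuracy)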