5.2 – GPU 加速运算
在 GPU 训练可以大幅提升运算速度. 而且 Torch 也有一套很好的 GPU 运算体系. 但是要强调的是:
用 GPU 训练 CNN
这份 GPU 的代码是依据之前这份CNN的代码修改的. 大概修改的地方包括将数据的形式变成 GPU 能读的形式, 然后将 CNN 也变成 GPU 能读的形式. 做法就是在后面加上 .cuda() , 很简单.
1 2 3 4 5 6 7
| ...
test_data = torchvision.datasets.MNIST(root=\'./mnist/\', train=False)
# !!!!!!!! 修改 test data 形式 !!!!!!!!! # test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1)).type(torch.FloatTensor)[:2000].cuda()/255\. # Tensor on GPU test_y = test_data.test_labels[:2000].cuda()
|
再来把我们的 CNN 参数也变成 GPU 兼容形式.
1 2 3 4 5 6 7
| class CNN(nn.Module): ...
cnn = CNN()
cnn.cuda()
|
然后就是在 train 的时候, 将每次的training data 变成 GPU 形式. .cuda()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| for epoch ..: for step, ...: b_x = Variable(x).cuda() b_y = Variable(y).cuda()
...
if step % 50 == 0: test_output = cnn(test_x)
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze()
accuracy = torch.sum(pred_y == test_y) / test_y.size(0) ...
test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() ... print(test_y[:10], \'real number\')
|
大功告成~
所以这也就是在我 github 代码 中的每一步的意义啦.
文章来源:莫烦