# MNIST digits dataset (training split).
# NOTE(review): this statement was previously collapsed onto one line that
# began with '#', so the whole call was a comment and `train_data` was never
# defined; the `\'` escapes around the root path were also invalid Python.
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,  # this is training data
    # Converts a PIL.Image or numpy.ndarray to a torch.FloatTensor of
    # shape (C x H x W) and normalizes it into the range [0.0, 1.0].
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,  # download it if you don't have it
)
# One optimization step on the current batch.
# NOTE(review): `decoded`, `b_y`, `loss_func` and `optimizer` are defined
# outside this view (presumably inside the training loop). When these four
# statements were collapsed onto one line, everything after the first '#'
# was a comment, so the gradient update never ran.
loss = loss_func(decoded, b_y)  # mean square error (per original note; loss_func defined elsewhere)
optimizer.zero_grad()  # clear gradients for this training step
loss.backward()  # backpropagation, compute gradients
optimizer.step()  # apply gradients
画3D图
3D 可视化图挺有趣的，还能拖动旋转观看，更加直观，也更好理解。
# Data to visualize: the first 200 training images, each flattened to a
# 28*28 = 784-element row and scaled by 1/255 into floats.
view_data = Variable(
    train_data.train_data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
)
encoded_data, _ = autoencoder(view_data)  # extract the compressed features

fig = plt.figure(2)
# Axes3D(fig) no longer attaches itself to the figure (deprecated in
# Matplotlib 3.4, removed in 3.6), which leaves the figure empty;
# add_subplot(projection='3d') works on both old and new versions.
ax = fig.add_subplot(111, projection='3d')  # 3D plot

# x, y, z coordinates: the first three columns of each encoded vector.
X = encoded_data.data[:, 0].numpy()
Y = encoded_data.data[:, 1].numpy()
Z = encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()  # label values

for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))  # color by label (s/9 maps digits 0-9 onto the colormap)
    ax.text(x, y, z, s, backgroundcolor=c)  # draw the label at its encoded position
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()