3.2 – 区分类型 (分类 Classification)
这次我们也是用最简单的途径来看看神经网络是怎么进行事物的分类.

建立数据集
我们创建一些假数据来模拟真实的情况. 比如两个二次分布的数据, 不过他们的均值都不一样.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| import torch from torch.autograd import Variable import matplotlib.pyplot as plt
n_data = torch.ones(100, 2) x0 = torch.normal(2*n_data, 1) y0 = torch.zeros(100) x1 = torch.normal(-2*n_data, 1) y1 = torch.ones(100)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor) y = torch.cat((y0, y1), ).type(torch.LongTensor)
x, y = Variable(x), Variable(y)
plt.scatter(x.data.numpy(), y.data.numpy()) plt.show()
|
建立神经网络
建立一个神经网络我们可以直接运用 torch 中的体系. 先定义所有的层属性( init() ), 然后再一层层搭建( forward(x) )层于层的关系链接. 这个和我们在前面 regression 的时候的神经网络基本没差. 建立关系的时候, 我们会用到激励函数.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| import torch import torch.nn.functional as F
class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.out = torch.nn.Linear(n_hidden, n_output)
def forward(self, x): x = F.relu(self.hidden(x)) x = self.out(x) return x
net = Net(n_feature=2, n_hidden=10, n_output=2)
print(net) """ Net ( (hidden): Linear (2 -> 10) (out): Linear (10 -> 2) ) """
|
训练网络
训练的步骤很简单, 如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss()
for t in range(100): out = net(x)
loss = loss_func(out, y)
optimizer.zero_grad() loss.backward() optimizer.step()
|
可视化训练过程
为了可视化整个训练的过程, 更好的理解是如何训练, 我们如下操作:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| import matplotlib.pyplot as plt
plt.ion() plt.show()
for t in range(100):
... loss.backward() optimizer.step()
if t % 2 == 0: plt.cla() prediction = torch.max(F.softmax(out), 1)[1] pred_y = prediction.data.numpy().squeeze() target_y = y.data.numpy() plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap=\'RdYlGn\') accuracy = sum(pred_y == target_y)/200 # 预测中有多少和真实值一样 plt.text(1.5, -4, \'Accuracy=%.2f\' % accuracy, fontdict={\'size\': 20, \'color\': \'red\'}) plt.pause(0.1)
plt.ioff() # 停止画图 plt.show()
|

所以这也就是在我 github 代码 中的每一步的意义啦.
文章来源:莫烦