defforward(self, x, h_state): r_out, h_state = self.rnn(x, h_state) r_out_reshaped = r_out.view(-1, HIDDEN_SIZE) # to 2D data outs = self.linear_layer(r_out_reshaped) outs = outs.view(-1, TIME_STEP, INPUT_SIZE) # to 3D data
训练
下面的代码就能实现动图的效果啦~开心, 可以看出, 我们使用 x 作为输入的 sin 值, 然后 y作为想要拟合的输出, cos 值. 因为他们两条曲线是存在某种关系的, 所以我们就能用 sin 来预测 cos. rnn 会理解他们的关系, 并用里面的参数分析出来这个时刻 sin 曲线上的点如何对应上 cos 曲线上的点.
loss = loss_func(prediction, y) # cross entropy loss optimizer.zero_grad() # clear gradients for this training step loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients