parser.add_argument('--world_size', default=2, help='Number of workers')
parser.add_argument('--log_interval', default=1, help='Log every log_interval episodes')
parser.add_argument('--gamma', default=0.1, help='how much to value future rewards')
parser.add_argument('--seed', default=1, help='random seed for reproducibility')
args = parser.parse_args()
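The observer and agent code below calls _remote_method and _call_method helpers to run a method on the object referenced by an RRef on that RRef's owner. These helpers are not reproduced in this section; the sketch below shows one way they can be written, assuming rpc_sync is also imported from torch.distributed.rpc.

from torch.distributed.rpc import rpc_sync


def _call_method(method, rref, *args, **kwargs):
    # invoke the method on the object held locally by this RRef
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    # run _call_method on the RRef's owner and block until the result returns
    args = [method, rref] + list(args)
    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)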
    def run_episode(self, agent_rref, n_steps):
        state, ep_reward = self.env.reset(), 0
        for step in range(n_steps):
            # send the state to the agent to get an action
            action = _remote_method(Agent.select_action, agent_rref, self.id, state)

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)

            # report the reward to the agent for training purpose
            _remote_method(Agent.report_reward, agent_rref, self.id, reward)
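On the agent side, Agent.select_action and Agent.report_reward are the two methods the observer invokes above. They are not shown in this section; the sketch below illustrates what they might look like, assuming the agent holds a REINFORCE-style policy network in self.policy and per-observer saved_log_probs and rewards dictionaries created in Agent.__init__.

class Agent:
    ...
    def select_action(self, ob_id, state):
        # run the policy locally on the agent and sample an action
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        # remember the log prob for this observer so finish_episode can use it
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    def report_reward(self, ob_id, reward):
        # store the reward reported by observer ob_id
        self.rewards[ob_id].append(reward)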
import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
class Agent:
    ...
    def run_episode(self, n_steps=0):
        futs = []
        for ob_rref in self.ob_rrefs:
            # make async RPC to kick off an episode on all observers
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    _call_method,
                    args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
                )
            )

        # wait until all observers have finished this episode
        for fut in futs:
            fut.wait()
class Agent:
    ...
    def finish_episode(self):
        # joins probs and rewards from different observers into lists
        R, probs, rewards = 0, [], []
        for ob_id in self.rewards:
            probs.extend(self.saved_log_probs[ob_id])
            rewards.extend(self.rewards[ob_id])

        # use the minimum observer reward to calculate the running reward
        min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

        # clear saved probs and rewards
        for ob_id in self.rewards:
            self.rewards[ob_id] = []
            self.saved_log_probs[ob_id] = []

        policy_loss, returns = [], []
        for r in rewards[::-1]:
            R = r + args.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + self.eps)
        for log_prob, R in zip(probs, returns):
            policy_loss.append(-log_prob * R)
        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        return min_reward
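The agent process (rank 0) drives training from a run_worker launcher: it initializes RPC, builds the Agent, and loops over episodes, calling run_episode and finish_episode until the running reward crosses the environment's threshold. The sketch below outlines that loop, assuming AGENT_NAME, OBSERVER_NAME, and TOTAL_EPISODE_STEP are defined earlier in the tutorial; the original fragment that follows picks up at the stopping check and the observer branch.

import os
from itertools import count


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # rank 0 is the agent
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
        agent = Agent(world_size)
        for i_episode in count(1):
            # spread the episode steps across the observers
            n_steps = int(TOTAL_EPISODE_STEP / (args.world_size - 1))
            agent.run_episode(n_steps=n_steps)
            last_reward = agent.finish_episode()

            if i_episode % args.log_interval == 0:
                print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                    i_episode, last_reward, agent.running_reward))
            # ... the stopping check and the observer branch continue below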
            if agent.running_reward > agent.reward_threshold:
                print("Solved! Running reward is now {}!".format(agent.running_reward))
                break
    else:
        # other ranks are the observer
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
        # observers passively wait for instructions from the agent

    # block until all rpcs finish, and shutdown the RPC instance
    rpc.shutdown()
Episode 10   Last reward: 26.00   Average reward: 10.01
Episode 20   Last reward: 16.00   Average reward: 11.27
Episode 30   Last reward: 49.00   Average reward: 18.62
Episode 40   Last reward: 45.00   Average reward: 26.09
Episode 50   Last reward: 44.00   Average reward: 30.03
Episode 60   Last reward: 111.00  Average reward: 42.23
Episode 70   Last reward: 131.00  Average reward: 70.11
Episode 80   Last reward: 87.00   Average reward: 76.51
Episode 90   Last reward: 86.00   Average reward: 95.93
Episode 100  Last reward: 13.00   Average reward: 123.93
Episode 110  Last reward: 33.00   Average reward: 91.39
Episode 120  Last reward: 73.00   Average reward: 76.38
Episode 130  Last reward: 137.00  Average reward: 88.08
Episode 140  Last reward: 89.00   Average reward: 104.96
Episode 150  Last reward: 97.00   Average reward: 98.74
Episode 160  Last reward: 150.00  Average reward: 100.87
Episode 170  Last reward: 126.00  Average reward: 104.38
Episode 180  Last reward: 500.00  Average reward: 213.74
Episode 190  Last reward: 322.00  Average reward: 300.22
Episode 200  Last reward: 165.00  Average reward: 272.71
Episode 210  Last reward: 168.00  Average reward: 233.11
Episode 220  Last reward: 184.00  Average reward: 195.02
Episode 230  Last reward: 284.00  Average reward: 208.32
Episode 240  Last reward: 395.00  Average reward: 247.37
Episode 250  Last reward: 500.00  Average reward: 335.42
Episode 260  Last reward: 500.00  Average reward: 386.30
Episode 270  Last reward: 500.00  Average reward: 405.29
Episode 280  Last reward: 500.00  Average reward: 443.29
Episode 290  Last reward: 500.00  Average reward: 464.65
Solved! Running reward is now 475.3163778435275!
    def forward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        # pass output to the remote decoder and get the decoded output back
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden
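The forward method above assumes that self.emb_table_rref and self.decoder_rref were created in the constructor by placing the EmbeddingTable and Decoder sub-modules on the parameter server with rpc.remote, while the LSTM stays local to the trainer. A minimal sketch of such a constructor, with the argument names (ps, ntoken, ninp, nhid, nlayers, dropout) assumed for illustration:

class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        # set up the embedding table remotely on the parameter server
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # keep the LSTM local to the trainer
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # set up the decoder remotely on the parameter server
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))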
Before introducing the distributed optimizer, let's add a helper function that generates a list of RRefs to the model parameters, which the distributed optimizer will consume. In local training, an application can call Module.parameters() to grab references to all parameter tensors and pass them to the local optimizer for subsequent updates. However, the same API does not work in the distributed training scenario, because some parameters live on remote machines. Therefore, instead of a list of parameter Tensors, the distributed optimizer takes a list of RRefs, one RRef per model parameter, covering both local and remote model parameters. The helper function is quite simple: it just calls Module.parameters() and creates a local RRef on each parameter.
def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs
class RNNModel(nn.Module):
    ...
    def parameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))
        # get RRefs of decoder
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params
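With parameter_rrefs available, the trainer can hand the full list of parameter RRefs to the distributed optimizer instead of plain parameter tensors. The training loop further below relies on model, criterion, opt, hidden, and dist_autograd; a minimal sketch of how they might be set up, assuming DistributedOptimizer comes from torch.distributed.optim and using illustrative model sizes.

import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer


def _run_trainer():
    # illustrative sizes, assumed for this sketch
    ntoken, ninp, nhid, nindices, nlayers = 10, 2, 3, 3, 4

    # embedding table and decoder live on the parameter server "ps"
    model = RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # one RRef per parameter, covering both local and remote parameters
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    # initial LSTM hidden state used by the training loop
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid),
    )

In a full program, the get_next_batch generator and the training loop shown below would typically run inside _run_trainer as well.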
def get_next_batch():
    for _ in range(5):
        data = torch.LongTensor(batch, nindices) % ntoken
        target = torch.LongTensor(batch, ntoken) % nindices
        yield data, target
# train for 10 iterations
for epoch in range(10):
    for data, target in get_next_batch():
        # create distributed autograd context
        with dist_autograd.context() as context_id:
            hidden[0].detach_()
            hidden[1].detach_()
            output, hidden = model(data, hidden)
            loss = criterion(output, target)
            # run distributed backward pass
            dist_autograd.backward(context_id, [loss])
            # run distributed optimizer
            opt.step(context_id)
            # not necessary to zero grads since they are
            # accumulated into the distributed autograd context
            # which is reset every iteration.
    print("Training epoch {}".format(epoch))
Finally, let's add some glue code to launch the parameter server and the trainer processes.
import os
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        _run_trainer()
    else:
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # the parameter server does nothing; it passively serves RPC requests
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)