import os
import random

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions
from torch.nn.parallel import DistributedDataParallel as DDP

# Illustrative sizes for the embedding table; the original excerpt does not
# define them.
NUM_EMBEDDINGS = 100
EMBEDDING_DIM = 16


def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts
    down RPC.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    # Use a different port for RPC than the one used by init_process_group
    # below to avoid conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions(
        init_method="tcp://localhost:29501")

    # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

        # Build the embedding table on the ps.
        emb_rref = rpc.remote(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"})

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(emb_rref, trainer_rank))
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2)

        # Initialize RPC so this trainer can receive work from the master.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)
        # The parameter server does nothing.
        pass

    # Block until all RPCs finish, then shut down RPC.
    rpc.shutdown()


class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part. The dense part is an
    nn.Linear module that is replicated across all trainers using
    DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
    stored on the parameter server.

    The model holds a Remote Reference to the embedding table on the parameter
    server.
    """
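
    # The class body is not shown in the original excerpt; the __init__ and
    # forward below are a minimal sketch of what the docstring above describes.
    # The layer sizes (EMBEDDING_DIM -> 8 classes) and the assumption that
    # trainer rank i owns CUDA device i are illustrative choices, not taken
    # from the source.
    def __init__(self, emb_rref, device):
        super(HybridModel, self).__init__()
        # Remote Reference to the nn.EmbeddingBag on the parameter server.
        self.emb_rref = emb_rref
        # Dense part: nn.Linear replicated across trainers via DDP.
        self.fc = DDP(
            torch.nn.Linear(EMBEDDING_DIM, 8).cuda(device),
            device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        # Sparse part: run the embedding lookup on the parameter server via RPC.
        emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
        # Dense part: run the DDP-wrapped nn.Linear locally on this trainer's GPU.
        return self.fc(emb_lookup.cuda(self.device))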


def _retrieve_embedding_parameters(emb_rref):
    # Runs on the parameter server, which owns emb_rref: wrap each parameter
    # of the embedding table in an RRef so the DistributedOptimizer can reach it.
    param_rrefs = []
    for param in emb_rref.local_value().parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs


def _run_trainer(emb_rref, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradient updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(emb_rref, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.
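    # (The code for this step is not shown in the original excerpt; the block
    # below is a sketch of the usual pattern. The SGD learning rate, the
    # CrossEntropyLoss criterion, and the toy get_next_batch generator are
    # illustrative assumptions.)

    # Parameters of the embedding table live on the parameter server, so fetch
    # RRefs to them over RPC using the helper defined above.
    model_parameter_rrefs = rpc.rpc_sync(
        "ps", _retrieve_embedding_parameters, args=(emb_rref,))

    # model.parameters() only covers the local (dense) parameters; wrap each of
    # them in an RRef as well.
    for param in model.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Setup a DistributedOptimizer over both the remote and local parameters.
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05)

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch(rank):
        # Minimal stand-in batch generator: random indices, one bag every 5
        # indices, and random targets (8 classes, matching the nn.Linear sketch
        # above) placed on this trainer's GPU.
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)
            offsets = torch.arange(0, num_indices, 5)
            target = torch.LongTensor(len(offsets)).random_(8).cuda(rank)
            yield indices, offsets, target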

    # Train for 100 epochs.
    for epoch in range(100):
        # Create a distributed autograd context for each iteration.
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run the distributed backward pass.
                dist_autograd.backward(context_id, [loss])

                # Run the distributed optimizer.
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a
                # different distributed autograd context which hosts
                # different grads.
        print("Training done for epoch {}".format(epoch))
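

# The original excerpt does not show how these workers are launched; the entry
# point below is a typical sketch, assuming 4 processes in total (2 trainers,
# 1 parameter server, 1 master).
if __name__ == "__main__":
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)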