First, DataParallel is single-process and multi-threaded, and it only works on a single machine, whereas DistributedDataParallel is multi-process and works for both single-machine and multi-machine training. Even on a single machine, DataParallel is usually slower than DistributedDataParallel because of GIL contention across threads, the model replication performed at every iteration, and the extra overhead of scattering inputs and gathering outputs.
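To make the contrast concrete, here is a minimal sketch of how a model is wrapped in each case. The model is a placeholder, and the DDP wrapper assumes a process group has already been initialized (for example by the setup() helper shown below) and that rank is the GPU index assigned to the current process.

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

model = nn.Linear(10, 5)  # placeholder model for illustration

# Single process, multi-threaded: DataParallel replicates the model onto every
# visible GPU inside one Python process, so all replicas share the GIL.
dp_model = nn.DataParallel(model)

# Multi-process: each process owns one GPU and wraps its own copy of the model.
# Assumes dist.init_process_group() has already been called in this process.
def wrap_with_ddp(model, rank):
    return DDP(model.to(rank), device_ids=[rank])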
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
    if sys.platform == 'win32':
        # Distributed package only covers collective communications with Gloo
        # backend and FileStore on Windows platform. Set init_method parameter
        # in init_process_group to a local file.
        # Example init_method="file:///f:/libtmp/some_file"
        init_method = "file:///{your local file path}"

        # initialize the process group
        dist.init_process_group(
            "gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size
        )
    else:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
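The fragments later in this section call cleanup() and run_demo() without showing them. A minimal sketch consistent with the setup() helper above might look like this, with mp.spawn launching one process per rank:

def cleanup():
    # Tear down the process group created in setup().
    dist.destroy_process_group()


def run_demo(demo_fn, world_size):
    # Spawn world_size processes; each one receives its rank as the first
    # argument, followed by args, and runs demo_fn(rank, world_size).
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)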
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
    # All processes should see same parameters as they all start from same
    # random parameters and gradients are synchronized in backward passes.
    # Therefore, saving it in one process is sufficient.
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# configure map_location properly
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
ddp_model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=map_location))

# Not necessary to use a dist.barrier() to guard the file deletion below
# as the AllReduce ops in the backward pass of DDP already served as
# a synchronization.
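The comment above refers to a file deletion that is not shown in this excerpt. In a full demo_checkpoint(rank, world_size) function wrapping this snippet, ddp_model would be a DDP-wrapped model created earlier in the same function (for example DDP(ToyModel().to(rank), device_ids=[rank])), and the checkpoint file would typically be removed by rank 0 after the training step before tearing down the process group:

if rank == 0:
    # Only one process created the file, so only one removes it.
    os.remove(CHECKPOINT_PATH)

cleanup()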
optimizer.zero_grad()
# outputs will be on dev1
outputs = ddp_mp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(dev1)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
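The training step above comes from a DDP-with-model-parallel demo; it assumes a toy model split across two devices and a per-process device assignment. The following sketch shows one way to set that up; ToyMpModel and the dev0/dev1 layout are assumptions inferred from the 10-in / 5-out shapes used above.

class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        # First half of the model lives on dev0, second half on dev1.
        self.net1 = nn.Linear(10, 10).to(dev0)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = self.relu(self.net1(x.to(self.dev0)))
        return self.net2(x.to(self.dev1))


def demo_model_parallel(rank, world_size):
    setup(rank, world_size)

    # Each process drives two GPUs, so consecutive device pairs go to each rank.
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    # device_ids must be left unset when the module spans multiple devices.
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
    # ... the optimizer.zero_grad() / forward / backward / step code shown above runs here,
    # followed by cleanup().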
if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    if n_gpus < 8:
        print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
    else:
        run_demo(demo_basic, 8)
        run_demo(demo_checkpoint, 8)
        run_demo(demo_model_parallel, 4)
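Note the different world sizes: demo_basic (defined earlier in this tutorial) and demo_checkpoint use one GPU per process, so they are launched with 8 processes on 8 GPUs, while demo_model_parallel places each model replica across two GPUs (dev0/dev1 in the sketch above) and is therefore launched with only 4 processes on the same 8 GPUs.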