Pytorch DDP 설정하는 법 (DDP baseline)
일단 아래는 baseline으로 쓸만한 데모 소스이다. (설명은 더 아래에 있다.) NLP 모델에서 pretraining과 finetuning을 번갈아 가며 할 수 있도록 설게했다. config.mode
는 pretrain과 finetune 중 하나, run_epoch 함수의 mode는 train/valid임에 유의. config.mode
는 당연히 필요 없다면 지우고 나머지 부분도 알아서 수정해가며 쓰면 된다.
AverageMeter은 내 기억이 맞다면 내가 Masked Autoencoder repository에서 가져온 것 같고, 각주에서는 또 이걸 다른 데서 가져왔다고 한다. utils.py 등에 넣어놓고 import하자.
# ============= train.py ==================
import logging, wandb, torch
import torch.distributed as dist, torch.nn.functional as F, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from types import SimpleNamespace
from tqdm import tqdm
from utils import AverageMeter, TqdmLoggingHandler
def run_epoch(model, optimizer, criterion, train_loader, valid_loader, epoch, mode, **kwargs):
loss_sum = acc = size = 0
loader = (train_loader if mode == 'train' else valid_loader)
if len(loader) == 0:
return
for data in tqdm(loader):
x, y = data
x = x.to(local_rank)
y = y.to(local_rank)
optimizer.zero_grad()
out = model(x)
if config.mode == 'pretrain':
y = model.emb(y)
loss = criterion(out, y)
if mode == 'train':
loss.backward()
optimizer.step()
loss_sum += loss.item()
if config.mode == 'finetune':
acc += torch.count_nonzero(out.argmax(axis=-1) == y).item()
size += len(x)
if local_rank == 0 and step.get() % 50 == 0:
wandb.log({'loss': loss}, step=step.count)
loss_mean = loss_sum / len(loader)
logging.info(f"Loss: {loss_mean}")
if local_rank == 0:
wandb.log({f'{mode}_loss': loss_mean, f'{mode}_acc': acc / size,
'lr': optimizer.param_groups[0]['lr'], 'epoch': epoch}, step=step.count)
def train(config):
logging.info(f"[ {config.mode} begin. ]")
model = MyModel(mode=config.mode)
model.to(local_rank)
if local_rank == 0:
wandb.init(project="2-5", name=config.mode)
wandb.watch(model)
wandb.config.update(config)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=0.1)
criterion = config.criterion
train_x, train_y, valid_x, valid_y = load_dataset(config.mode)
train_dataset = MyDataset(train_x, train_y, config)
valid_dataset = MyDataset(valid_x, valid_y, config)
logging.info(f'Loading data is finished! train: {len(train_dataset)}, valid: {len(valid_dataset)}')
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
train_loader = DataLoader(dataset=train_dataset, sampler=train_sampler, shuffle=False, batch_size=config.batch_size // world_size)
valid_loader = DataLoader(dataset=valid_dataset, sampler=valid_sampler, shuffle=False, batch_size=config.batch_size // world_size)
for epoch in range(1, 1 + config.max_epoch):
logging.info(f"Epoch {epoch}")
train_sampler.set_epoch(epoch)
valid_sampler.set_epoch(epoch)
model.train()
run_epoch(mode='train', **locals())
model.eval()
with torch.no_grad():
run_epoch(mode='valid', **locals())
scheduler.step()
if local_rank == 0:
torch.save(model.state_dict(), f'{config.mode}_{epoch}.pkl')
if local_rank == 0:
wandb.finish()
dist.barrier()
if __name__ == "__main__":
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank()
world_size = dist.get_world_size()
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if local_rank == 0 else logging.WARNING,
handlers=[TqdmLoggingHandler()])
logging.info(f"Training begin. world_size: {world_size}")
config = SimpleNamespace()
config.max_epoch = 9
config.batch_size = 4096
config.lr = 1e-3
config.weight_decay = 0
config.step_size = 3
config.criterion = nn.MSELoss()
config.mode = 'pretrain'
step = AverageMeter()
train(config)
# =========== utils.py ==============
import logging
from tqdm import tqdm
class AverageMeter(object):
"""Computes and stores the average and current value
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val=0, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get(self):
self.count += 1
return self.count
class TqdmLoggingHandler(logging.StreamHandler):
"""Avoid tqdm progress bar interruption by logger's output to console"""
# see logging.StreamHandler.eval method:
# https://github.com/python/cpython/blob/d2e2534751fd675c4d5d3adc208bf4fc984da7bf/Lib/logging/__init__.py#L1082-L1091
# and tqdm.write method:
# https://github.com/tqdm/tqdm/blob/f86104a1f30c38e6f80bfd8fb16d5fcde1e7749f/tqdm/std.py#L614-L620
def emit(self, record):
try:
msg = self.format(record)
tqdm.write(msg, end=self.terminator)
except RecursionError:
raise
except Exception:
self.handleError(record)
주의: validation 도중(with torch.no_grad() 내부) loss.backward()
를 사용하면 오류가 난다. 그래서 mode=='train'일때만 loss.backward()
를 실행해줘야 하며, 그렇게 구현했다.
실행 방법: 외부 파이썬 스크립트로 subprocess.run(f"torchrun --nproc_per_node={n_gpus} train.py", shell=True)
, (이 때 n_gpus
는 torch.cuda.device_count()
)를 실행하든 쉘에서 torchrun --nproc_per_node=(원하는 노드수, 보통은 gpu수) train.py
를 실행하든 하면 된다.
DDP가 실행이 되도록 옮길 때 핵심은 1) global context에서 dist.init_process_group(backend="nccl")
실행해줄것 2) DataLoader init 당시에 DistributedSampler을 넣어줄 것 3) epoch마다 sampler에서 set_epoch 설정해줄 것 4) 메인 프로세스에서만 돌아가는 것(예를 들어 print, 모델 저장 등)은 local_rank가 0일 때만 해줄 것 5) torchrun으로 위 스크립트 실행해줄 것
부가적으로, 만약 특정 지점 이후로 특정 프로세스만 앞서나가지 못하게 하고 싶다면 dist.barrier()
을 써줄것
참고로 DistributedSampler을 설정하면 DataLoader에서 shuffle=False로 설정해야 하고(shuffle=True로 설정하면 문제가 생기는지는 모르겠다. 일단 shuffle을 Sampler단에서 해줘서 false로 설정해도 괜찮다고는 들었다.), sampler에서 epoch마다 set_epoch를 설정하지 않으면 매 epoch에서 똑같은 순서로 데이터가 sampling된다. 데이터 크기가 크면 영향이 적겠으나 데이터 크기가 작으면 영향이 있을 수 있다.
또한 print는 logging 라이브러리를 이용해 local_rank==0일 때만 출력되도록 설정했다. ifmain 내부를 참고하자. logging Handler중 TqdmLoggingHandler는 tqdm과 logging 라이브러리를 동시에 사용하면 메시지가 섞이는 이슈 때문에 사용했고, 더 나은 방법이 있는지 모르겠다. 핸들러를 일일이 설정하는 건 귀찮잖아..
torch.multiprocessing인가? 이런 모듈을 이용해 여러 프로세스를 실행하도록 만드는 방법도 있는데, torchrun에서 일어나지 않는 버그가 일어나기도 한다고 들었다. (스택오버플로에서 봤다.) 개인적으로 양쪽 모두 구현해봤으나 위 포스트에 올린 쪽이 더 간편하다고 느껴 이대로 포스트를 올린다.