import argparse
import math
import os
import time

from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm

import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe
from torch.distributed.pipeline.sync.utils import partition_model
from torch.optim import Adam
from torch.utils.data import DataLoader


def sizeof_fmt(num, suffix="B"):
    # Walk up the binary prefixes until the value fits below 1024; fall back
    # to "Pi" for anything larger so the function never returns None.
    for unit in ["", "Ki", "Mi", "Gi", "Ti"]:
        if abs(num) < 1024.0:
            return f"{num:3.2f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:3.2f}Pi{suffix}"
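
# Example outputs (binary prefixes, base 1024):
#   sizeof_fmt(512)         -> "512.00B"
#   sizeof_fmt(3 * 1024**3) -> "3.00GiB"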


def init_random_seed(seed: int):
    import numpy

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    numpy.random.seed(seed)


# Global counter bumped by every TransformerDecoderLayer.forward call.
iteration_count = 0


class EmbeddingLayer(nn.Embedding):
    def __init__(self, ntoken, ninp, initrange):
        super().__init__(ntoken, ninp)
        self.ninp = ninp
        nn.init.uniform_(self.weight, -initrange, initrange)

    def forward(self, src):
        return super().forward(src) * math.sqrt(self.ninp)
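
# Scaling by sqrt(ninp) follows "Attention Is All You Need", which multiplies
# embedding outputs by sqrt(d_model) so their magnitude matches the positional
# encodings added by the next layer.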


class PositionalEncodingLayer(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape (max_len, 1, d_model) so the buffer broadcasts over the batch
        # dimension of sequence-first input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
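
# The pe buffer holds the standard sinusoidal encoding:
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))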


class TransformerDecoderLayer(nn.TransformerEncoderLayer):
    """Though this class inherits from torch.nn.TransformerEncoderLayer,
    it functions as a decoder in this model because it applies a causal
    (square subsequent) attention mask."""

    def __init__(self, ninp, nhead, nhid, dropout):
        super().__init__(ninp, nhead, nhid, dropout)
        self.src_mask = None

    def forward(self, src):
        global iteration_count
        iteration_count += 1

        # Rebuild the causal mask only when the sequence length changes.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        return super().forward(src, self.src_mask)
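
# For a sequence of length 3, the generated causal mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so each position attends only to itself and earlier positions.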


class LinearLayer(nn.Linear):
    def __init__(self, ninp, ntoken, initrange):
        super().__init__(ninp, ntoken)
        nn.init.zeros_(self.bias)
        nn.init.uniform_(self.weight, -initrange, initrange)


class TransformerLMSequential(nn.Sequential):
    """A small language model based on the design of GPT-2 using nn.Sequential
    for compatibility with Pipe"""

    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
        layers = [
            EmbeddingLayer(ntokens, ninp, initrange),
            PositionalEncodingLayer(ninp, dropout),
        ]
        for _ in range(ndecoder):
            layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))

        layers.append(LinearLayer(ninp, ntokens, initrange))
        super().__init__(*layers)
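
# The flattened layer sequence is [Embedding, PositionalEncoding,
# ndecoder x TransformerDecoderLayer, Linear], so len(model) == ndecoder + 3;
# partition_model later splits this sequence across devices by layer index.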


def make_model(args, device, ntokens):
    ninp = 2048  # embedding dimension
    nhid = 2048  # dimension of the feedforward network in nn.TransformerEncoder
    nhead = 32  # number of heads in the multi-head attention models
    dropout = 0
    initrange = 0.1
    ndecoder = args.num_decoder_layers

    model = TransformerLMSequential(
        ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    lr = 0.01  # learning rate

    def make_adam(model):
        return Adam(model.parameters(), lr=lr)

    # Return a factory rather than an optimizer instance: Adam must capture
    # the parameters only after the model is partitioned and wrapped in Pipe.
    optimizer = make_adam

    return model, criterion, optimizer


def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
    model.train()

    total_loss = 0.0
    start_time = time.time()
    word_counter = 0

    # Build the optimizer now that the model is partitioned and wrapped in Pipe.
    optimizer = optimizer(model)

    def get_first_device(model):
        if model.devices:
            return model.devices[0]
        else:
            return torch.cuda.current_device()

    def get_last_device(model):
        if model.devices:
            return model.devices[-1]
        else:
            return torch.cuda.current_device()

    print(
        f"Number of parameters for model: {sum(p.numel() for p in model.parameters())}"
    )
    for i, batch in enumerate(lm_dataloader):
        if args.max_batch and i > args.max_batch:
            break
        optimizer.zero_grad()
        try:
            # Inputs go to the first pipeline stage; Pipe returns an RRef, so
            # fetch the output tensor with local_value().
            tmp = batch["input"].to(get_first_device(model))
            output = model(tmp).local_value()
        except Exception as e:
            raise RuntimeError(
                f"training failed on {torch.distributed.get_rank()}"
            ) from e

        # The output lives on the last pipeline stage; move the target there.
        target = batch["target"].to(get_last_device(model))
        output = output.to(target.device)

        loss = criterion(output.view(-1, vocab_size), target.view(-1))
        loss.backward()
        del target
        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 1
        word_counter += batch["ntokens"]
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print(
                f"| batch {i:5d} | wps {word_counter / elapsed:5.2f} "
                f"| loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}"
            )
            word_counter = 0
            total_loss = 0
            start_time = time.time()

    print("Peak memory usage for GPUs: ", end="")
    for i in range(len(model.devices)):
        print(
            f"cuda:{i}: {sizeof_fmt(torch.cuda.memory_stats(i)['allocated_bytes.all.peak'])}, ",
            end="",
        )
    print()


def generate_balance(num_devices, num_layers):
    """Spread num_layers across num_devices as evenly as possible, giving
    earlier devices the extra layer when the split is uneven."""
    balance = []
    layers_assigned = 0
    for i in range(num_devices):
        x = math.ceil((num_layers - layers_assigned) / (num_devices - i))
        balance.append(x)
        layers_assigned += x
    return balance
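
# Worked example: generate_balance(4, 10) -> [3, 3, 2, 2]
# (ten layers over four devices; the first two stages take the extra layer).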


def make_model_and_data(args, device):
    # Fall back to a default device only when the caller does not supply one.
    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
    vocab_size = 10000
    model, criterion, optimizer = make_model(args, device, vocab_size)
    lm_dataset = BenchmarkLMDataset()
    lm_dataloader = DataLoader(
        lm_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_sentences_lm,
    )
    return {
        "model": model,
        "criterion": criterion,
        "optimizer": optimizer,
        "data": lm_dataloader,
        "vocab_size": vocab_size,
    }


def bench_single_process(args):
    os.environ.update({"MASTER_ADDR": args.host})
    os.environ.update({"MASTER_PORT": "10638"})

    # Pipe requires the RPC framework to be initialized, even when a single
    # process drives all local GPUs.
    rpc.init_rpc(
        "worker",
        rank=0,
        world_size=1,
    )

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_devices = min(args.num_devices, num_devices)
    assert num_devices > 0
    init_random_seed(0)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    blob = make_model_and_data(args, device)
    model = blob["model"]

    # Split the nn.Sequential into per-device stages and wrap it in Pipe.
    balance = generate_balance(num_devices, len(model))
    model = partition_model(model, balance)
    p = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint)
    del model
    del blob["model"]

    train(
        blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args
    )


parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument(
    "--chunks", type=int, default=4, help="number of microbatches per batch"
)
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--max-batch", type=int, default=10, help="max number of batches")
parser.add_argument(
    "--num-decoder-layers",
    type=int,
    default=10,
    help="number of decoder layers in the model",
)
parser.add_argument(
    "--checkpoint",
    default="except_last",
    choices=["always", "except_last", "never"],
    help="checkpointing strategy for pipe",
)
parser.add_argument(
    "--num-devices", type=int, default=4, help="number of GPU devices to use"
)

if __name__ == "__main__":
    args = parser.parse_args()
    print(f"Running benchmark with args: {args}")
    bench_single_process(args)
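
# Example invocation (assuming this script is saved as pipe.py; adjust the
# flags to match the available hardware):
#   python pipe.py --num-devices 2 --chunks 4 --num-decoder-layers 10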