乡下人产国偷v产偷v自拍,国产午夜片在线观看,婷婷成人亚洲综合国产麻豆,久久综合给合久久狠狠狠9

  • <output id="e9wm2"></output>
    <s id="e9wm2"><nobr id="e9wm2"><ins id="e9wm2"></ins></nobr></s>

    • 分享

      PyTorch Lightning工具學(xué)習(xí)

       520jefferson 2020-12-08

      來(lái)源 | GiantPandaCV

      編輯 | pprp

      【導(dǎo)讀】Pytorch Lightning是在Pytorch基礎(chǔ)上進(jìn)行封裝的庫(kù)(可以理解為keras之于tensorflow),為了讓用戶能夠脫離PyTorch一些繁瑣的細(xì)節(jié),專注于核心代碼的構(gòu)建,提供了許多實(shí)用工具,可以讓實(shí)驗(yàn)更加高效。本文將介紹安裝方法、設(shè)計(jì)邏輯、轉(zhuǎn)化的例子等內(nèi)容。

      PyTorch Lightning中提供了以下比較方便的功能:

      • multi-GPU訓(xùn)練
      • 半精度訓(xùn)練
      • TPU 訓(xùn)練
      • 將訓(xùn)練細(xì)節(jié)進(jìn)行抽象,從而可以快速迭代
      Pytorch Lightning

      1. 簡(jiǎn)單介紹

      PyTorch lightning 是為AI相關(guān)的專業(yè)的研究人員、研究生、博士等人群開發(fā)的。PyTorch就是William Falcon在他的博士階段創(chuàng)建的,目標(biāo)是讓AI研究擴(kuò)展性更強(qiáng),忽略一些耗費(fèi)時(shí)間的細(xì)節(jié)。

      目前PyTorch Lightning庫(kù)已經(jīng)有了一定的影響力,star已經(jīng)1w+,同時(shí)有超過1千多的研究人員在一起維護(hù)這個(gè)框架。

      PyTorch Lightning庫(kù)

      同時(shí)PyTorch Lightning也在隨著PyTorch版本的更新也在不停迭代。

      版本支持情況

      官方文檔也有支持,正在不斷更新:

      官方文檔

      下面介紹一下如何安裝。

      2. 安裝方法

      Pytorch Lightning安裝非常方便,推薦使用conda環(huán)境進(jìn)行安裝。

      source activate you_env
      pip install pytorch-lightning

      或者直接用pip安裝:

      pip install pytorch-lightning

      或者通過conda安裝:

      conda install pytorch-lightning -c conda-forge

      3. Lightning的設(shè)計(jì)思想

      Lightning將大部分AI相關(guān)代碼分為三個(gè)部分:

      • 研究代碼,主要是模型的結(jié)構(gòu)、訓(xùn)練等部分。被抽象為L(zhǎng)ightningModule類。

      • 工程代碼,這部分代碼重復(fù)性強(qiáng),比如16位精度,分布式訓(xùn)練。被抽象為Trainer類。

      • 非必要代碼,這部分代碼和實(shí)驗(yàn)沒有直接關(guān)系,不加也可以,加上可以輔助,比如梯度檢查,log輸出等。被抽象為Callbacks類。

      Lightning將研究代碼劃分為以下幾個(gè)組件:

      • 模型
      • 數(shù)據(jù)處理
      • 損失函數(shù)
      • 優(yōu)化器

      以上四個(gè)組件都將集成到LightningModule類中,是在Module類之上進(jìn)行了擴(kuò)展,進(jìn)行了功能性補(bǔ)充,比如原來(lái)優(yōu)化器使用在main函數(shù)中,是一種面向過程的用法,現(xiàn)在集成到LightningModule中,作為一個(gè)類的方法。

      4. LightningModule生命周期

      這部分參考了https://zhuanlan.zhihu.com/p/120331610 和 官方文檔 https://pytorch-lightning./en/latest/trainer.html

      在這個(gè)模塊中,將PyTorch代碼按照五個(gè)部分進(jìn)行組織:

      • Computations(init) 初始化相關(guān)計(jì)算
      • Train Loop(training_step) 每個(gè)step中執(zhí)行的代碼
      • Validation Loop(validation_step) 在一個(gè)epoch訓(xùn)練完以后執(zhí)行Valid
      • Test Loop(test_step) 在整個(gè)訓(xùn)練完成以后執(zhí)行Test
      • Optimizer(configure_optimizers) 配置優(yōu)化器等

      展示一個(gè)最簡(jiǎn)代碼:

      >>> import pytorch_lightning as pl
      >>> class LitModel(pl.LightningModule):
      ...
      ...     def __init__(self):
      ...         super().__init__()
      ...         self.l1 = torch.nn.Linear(28 * 28, 10)
      ...
      ...     def forward(self, x):
      ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
      ...
      ...     def training_step(self, batch, batch_idx):
      ...         x, y = batch
      ...         y_hat = self(x)
      ...         loss = F.cross_entropy(y_hat, y)
      ...         return loss
      ...
      ...     def configure_optimizers(self):
      ...         return torch.optim.Adam(self.parameters(), lr=0.02)

      那么整個(gè)生命周期流程是如何組織的?

      4.1 準(zhǔn)備工作

      這部分包括LightningModule的初始化、準(zhǔn)備數(shù)據(jù)、配置優(yōu)化器。每次只執(zhí)行一次,相當(dāng)于構(gòu)造函數(shù)的作用。

      • __init__()(初始化 LightningModule )
      • prepare_data() (準(zhǔn)備數(shù)據(jù),包括下載數(shù)據(jù)、預(yù)處理等等)
      • configure_optimizers() (配置優(yōu)化器)

      4.2 測(cè)試 驗(yàn)證部分

      實(shí)際運(yùn)行代碼前,會(huì)隨即初始化模型,然后運(yùn)行一次驗(yàn)證代碼,這樣可以防止在你訓(xùn)練了幾個(gè)epoch之后要進(jìn)行Valid的時(shí)候發(fā)現(xiàn)驗(yàn)證部分出錯(cuò)。主要測(cè)試下面幾個(gè)函數(shù):

      • val_dataloader()
      • validation_step()
      • validation_epoch_end()

      4.3 加載數(shù)據(jù)

      調(diào)用以下方法進(jìn)行加載數(shù)據(jù)。

      • train_dataloader()
      • val_dataloader()

      4.4 訓(xùn)練

      • 每個(gè)batch的訓(xùn)練被稱為一個(gè)step,故先運(yùn)行train_step函數(shù)。

      • 當(dāng)經(jīng)過多個(gè)batch, 默認(rèn)49個(gè)step的訓(xùn)練后,會(huì)進(jìn)行驗(yàn)證,運(yùn)行validation_step函數(shù)。

      • 當(dāng)完成一個(gè)epoch的訓(xùn)練以后,會(huì)對(duì)整個(gè)epoch結(jié)果進(jìn)行驗(yàn)證,運(yùn)行validation_epoch_end函數(shù)

      • (option)如果需要的話,可以調(diào)用測(cè)試部分代碼:

        • test_dataloader()
        • test_step()
        • test_epoch_end()

      5. 示例

      以MNIST為例,將PyTorch版本代碼轉(zhuǎn)為PyTorch Lightning。

      5.1 PyTorch版本訓(xùn)練MNIST

      對(duì)于一個(gè)PyTorch的代碼來(lái)說,一般是這樣構(gòu)建網(wǎng)絡(luò)(源碼來(lái)自PyTorch中的example庫(kù))。

      class Net(nn.Module):
          def __init__(self):
              super(Net, self).__init__()
              self.conv1 = nn.Conv2d(1, 32, 3, 1)
              self.conv2 = nn.Conv2d(32, 64, 3, 1)
              self.dropout1 = nn.Dropout(0.25)
              self.dropout2 = nn.Dropout(0.5)
              self.fc1 = nn.Linear(9216, 128)
              self.fc2 = nn.Linear(128, 10)

          def forward(self, x):
              x = self.conv1(x)
              x = F.relu(x)
              x = self.conv2(x)
              x = F.relu(x)
              x = F.max_pool2d(x, 2)
              x = self.dropout1(x)
              x = torch.flatten(x, 1)
              x = self.fc1(x)
              x = F.relu(x)
              x = self.dropout2(x)
              x = self.fc2(x)
              output = F.log_softmax(x, dim=1)
              return output

      還有兩個(gè)主要工作是構(gòu)建訓(xùn)練函數(shù)和測(cè)試函數(shù)。

      在訓(xùn)練函數(shù)中需要完成:

      • 數(shù)據(jù)獲取 data, target = data.to(device), target.to(device)
      • 清空優(yōu)化器梯度 optimizer.zero_grad()
      • 前向傳播 output = model(data)
      • 計(jì)算損失函數(shù) loss = F.nll_loss(output, target)
      • 反向傳播 loss.backward()
      • 優(yōu)化器進(jìn)行單次優(yōu)化 optimizer.step()
      def train(args, model, device, train_loader, optimizer, epoch):
          model.train()
          for batch_idx, (data, target) in enumerate(train_loader):
              data, target = data.to(device), target.to(device)
              optimizer.zero_grad()
              output = model(data)
              loss = F.nll_loss(output, target)
              loss.backward()
              optimizer.step()
              if batch_idx % args.log_interval == 0:
                  print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                      epoch, batch_idx * len(data), len(train_loader.dataset),
                      100. * batch_idx / len(train_loader), loss.item()))
                  if args.dry_run:
                      break

      def test(model, device, test_loader):
          model.eval()
          test_loss = 0
          correct = 0
          with torch.no_grad():
              for data, target in test_loader:
                  data, target = data.to(device), target.to(device)
                  output = model(data)
                  test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
                  pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                  correct += pred.eq(target.view_as(pred)).sum().item()

          test_loss /= len(test_loader.dataset)

          print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
              test_loss, correct, len(test_loader.dataset),
              100. * correct / len(test_loader.dataset)))

      其他部分比如數(shù)據(jù)加載、數(shù)據(jù)增廣、優(yōu)化器、訓(xùn)練流程都是在main中執(zhí)行的,采用的是一種面向過程的方法。

      def main():
          # Training settings
          parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
          parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                              help='input batch size for training (default: 64)')
          parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                              help='input batch size for testing (default: 1000)')
          parser.add_argument('--epochs', type=int, default=14, metavar='N',
                              help='number of epochs to train (default: 14)')
          parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                              help='learning rate (default: 1.0)')
          parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                              help='Learning rate step gamma (default: 0.7)')
          parser.add_argument('--no-cuda', action='store_true', default=False,
                              help='disables CUDA training')
          parser.add_argument('--dry-run', action='store_true', default=False,
                              help='quickly check a single pass')
          parser.add_argument('--seed', type=int, default=1, metavar='S',
                              help='random seed (default: 1)')
          parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                              help='how many batches to wait before logging training status')
          parser.add_argument('--save-model', action='store_true', default=False,
                              help='For Saving the current Model')
          args = parser.parse_args()
          use_cuda = not args.no_cuda and torch.cuda.is_available()

          torch.manual_seed(args.seed)

          device = torch.device('cuda' if use_cuda else 'cpu')

          train_kwargs = {'batch_size': args.batch_size}
          test_kwargs = {'batch_size': args.test_batch_size}
          if use_cuda:
              cuda_kwargs = {'num_workers': 1,
                             'pin_memory': True,
                             'shuffle': True}
              train_kwargs.update(cuda_kwargs)
              test_kwargs.update(cuda_kwargs)

          transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
              ])
          dataset1 = datasets.MNIST('../data', train=True, download=True,
                             transform=transform)
          dataset2 = datasets.MNIST('../data', train=False,
                             transform=transform)
          train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
          test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

          model = Net().to(device)
          optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

          scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
          for epoch in range(1, args.epochs + 1):
              train(args, model, device, train_loader, optimizer, epoch)
              test(model, device, test_loader)
              scheduler.step()

          if args.save_model:
              torch.save(model.state_dict(), 'mnist_cnn.pt')

      5.2 Lightning版本訓(xùn)練MNIST

      第一部分,也就是歸為研究代碼,主要是模型的結(jié)構(gòu)、訓(xùn)練等部分。被抽象為L(zhǎng)ightningModule類。

      class LitClassifier(pl.LightningModule):
          def __init__(self, hidden_dim=128, learning_rate=1e-3):
              super().__init__()
              self.save_hyperparameters()

              self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
              self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

          def forward(self, x):
              x = x.view(x.size(0), -1)
              x = torch.relu(self.l1(x))
              x = torch.relu(self.l2(x))
              return x

          def training_step(self, batch, batch_idx):
              x, y = batch
              y_hat = self(x)
              loss = F.cross_entropy(y_hat, y)
              return loss

          def validation_step(self, batch, batch_idx):
              x, y = batch
              y_hat = self(x)
              loss = F.cross_entropy(y_hat, y)
              self.log('valid_loss', loss)

          def test_step(self, batch, batch_idx):
              x, y = batch
              y_hat = self(x)
              loss = F.cross_entropy(y_hat, y)
              self.log('test_loss', loss)

          def configure_optimizers(self):
              return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

          @staticmethod
          def add_model_specific_args(parent_parser):
              parser = ArgumentParser(parents=[parent_parser], add_help=False)
              parser.add_argument('--hidden_dim', type=int, default=128)
              parser.add_argument('--learning_rate', type=float, default=0.0001)
              return parser

      可以看出,和PyTorch版本最大的不同之處在于多了幾個(gè)流程處理函數(shù):

      • training_step,相當(dāng)于訓(xùn)練過程中處理一個(gè)batch的內(nèi)容
      • validation_step,相當(dāng)于驗(yàn)證過程中處理一個(gè)batch的內(nèi)容
      • test_step, 同上
      • configure_optimizers, 這部分用于處理optimizer和scheduler
      • add_module_specific_args代表這部分控制的是與模型相關(guān)的參數(shù)

      除此以外,main函數(shù)主要有以下幾個(gè)部分:

      • args參數(shù)處理
      • data部分
      • model部分
      • 訓(xùn)練部分
      • 測(cè)試部分
      def cli_main():
          pl.seed_everything(1234) # 這個(gè)是用于固定seed用

          # args
          parser = ArgumentParser()
          parser = pl.Trainer.add_argparse_args(parser)
          parser = LitClassifier.add_model_specific_args(parser)
          parser = MNISTDataModule.add_argparse_args(parser)
          args = parser.parse_args()

          # data
          dm = MNISTDataModule.from_argparse_args(args)

          # model
          model = LitClassifier(args.hidden_dim, args.learning_rate)

          # training
          trainer = pl.Trainer.from_argparse_args(args)
          trainer.fit(model, datamodule=dm)

          result = trainer.test(model, datamodule=dm)
          pprint(result)

      可以看出Lightning版本的代碼代碼量略低于PyTorch版本,但是同時(shí)將一些細(xì)節(jié)忽略了,比如訓(xùn)練的具體流程直接使用fit搞定,這樣不會(huì)出現(xiàn)忘記清空optimizer等低級(jí)錯(cuò)誤。

      6. 評(píng)價(jià)

      總體來(lái)說,PyTorch Lightning是一個(gè)發(fā)展迅速的框架,如同fastai、keras、ignite等二次封裝的框架一樣,雖然易用性得到了提升,讓用戶可以通過更短的代碼完成任務(wù),但是遇到錯(cuò)誤的時(shí)候,往往就需要查看API甚至涉及框架源碼才能夠解決。前者降低門檻,后者略微提升了門檻。

      筆者使用這個(gè)框架大概一周了,從使用者角度來(lái)談?wù)剝?yōu)缺點(diǎn):

      6.1 優(yōu)點(diǎn)

      • 簡(jiǎn)化了部分代碼,之前如果要轉(zhuǎn)到GPU上,需要用to(device)方法判斷,然后轉(zhuǎn)過去。有了PyTorch lightning的幫助,可以自動(dòng)幫你處理,通過設(shè)置trainer中的gpus參數(shù)即可。
      • 提供了一些有用的工具,比如混合精度訓(xùn)練、分布式訓(xùn)練、Horovod
      • 代碼移植更加容易
      • API比較完善,大部分都有例子,少部分講的不夠詳細(xì)。
      • 社區(qū)還是比較活躍的,如果有問題,可以在issue中提問。
      • 實(shí)驗(yàn)結(jié)果整理的比較好,將每次實(shí)驗(yàn)劃分為version 0-n,同時(shí)可以用tensorboard比較多個(gè)實(shí)驗(yàn),非常友好。

      6.2 缺點(diǎn)

      • 引入了一些新的概念,進(jìn)一步加大了使用者的學(xué)習(xí)成本,比如pl_bolts
      • 很多原本習(xí)慣于在Pytorch中使用的功能,在PyTorch Lightning中必須查API才能使用,比如我想用scheduler,就需要去查API,然后發(fā)現(xiàn)在configure_optimizers函數(shù)中實(shí)現(xiàn),然后模仿demo實(shí)現(xiàn),因此也帶來(lái)了一定的門檻。
      • 有些報(bào)錯(cuò)比較迷,筆者曾遇到過執(zhí)行的時(shí)候發(fā)現(xiàn)多線程出問題,比較難以排查,最后通過更改distributed_backend得到了解決。遇到新的坑要去API里找答案,如果沒有解決繼續(xù)去Issue里找答案。

      7. 參考

      • 【1】 https://zhuanlan.zhihu.com/p/120331610

      • 【2】https://pytorch-lightning./en/latest/introduction_guide.html

      • 【3】https://github.com/pytorch/examples/blob/master/mnist/main.py

      • 【4】 https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py

        本站是提供個(gè)人知識(shí)管理的網(wǎng)絡(luò)存儲(chǔ)空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點(diǎn)。請(qǐng)注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購(gòu)買等信息,謹(jǐn)防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請(qǐng)點(diǎn)擊一鍵舉報(bào)。
        轉(zhuǎn)藏 分享 獻(xiàn)花(0

        0條評(píng)論

        發(fā)表

        請(qǐng)遵守用戶 評(píng)論公約

        類似文章 更多