PyTorch Lightning中提供了以下比較方便的功能:
- 將訓(xùn)練細(xì)節(jié)進(jìn)行抽象,從而可以快速迭代
Pytorch Lightning
PyTorch lightning 是為AI相關(guān)的專業(yè)的研究人員、研究生、博士等人群開發(fā)的。PyTorch就是William Falcon在他的博士階段創(chuàng)建的,目標(biāo)是讓AI研究擴(kuò)展性更強(qiáng),忽略一些耗費(fèi)時(shí)間的細(xì)節(jié)。
目前PyTorch Lightning庫(kù)已經(jīng)有了一定的影響力,star已經(jīng)1w+,同時(shí)有超過1千多的研究人員在一起維護(hù)這個(gè)框架。
PyTorch Lightning庫(kù)同時(shí)PyTorch Lightning也在隨著PyTorch版本的更新也在不停迭代。
版本支持情況官方文檔也有支持,正在不斷更新:
官方文檔下面介紹一下如何安裝。
Pytorch Lightning安裝非常方便,推薦使用conda環(huán)境進(jìn)行安裝。
source activate you_env
pip install pytorch-lightning
或者直接用pip安裝:
pip install pytorch-lightning
或者通過conda安裝:
conda install pytorch-lightning -c conda-forge
3. Lightning的設(shè)計(jì)思想
Lightning將大部分AI相關(guān)代碼分為三個(gè)部分:
研究代碼,主要是模型的結(jié)構(gòu)、訓(xùn)練等部分。被抽象為L(zhǎng)ightningModule類。
工程代碼,這部分代碼重復(fù)性強(qiáng),比如16位精度,分布式訓(xùn)練。被抽象為Trainer類。
非必要代碼,這部分代碼和實(shí)驗(yàn)沒有直接關(guān)系,不加也可以,加上可以輔助,比如梯度檢查,log輸出等。被抽象為Callbacks類。
Lightning將研究代碼劃分為以下幾個(gè)組件:
以上四個(gè)組件都將集成到LightningModule類中,是在Module類之上進(jìn)行了擴(kuò)展,進(jìn)行了功能性補(bǔ)充,比如原來(lái)優(yōu)化器使用在main函數(shù)中,是一種面向過程的用法,現(xiàn)在集成到LightningModule中,作為一個(gè)類的方法。
這部分參考了https://zhuanlan.zhihu.com/p/120331610 和 官方文檔 https://pytorch-lightning./en/latest/trainer.html
在這個(gè)模塊中,將PyTorch代碼按照五個(gè)部分進(jìn)行組織:
- Computations(init) 初始化相關(guān)計(jì)算
- Train Loop(training_step) 每個(gè)step中執(zhí)行的代碼
- Validation Loop(validation_step) 在一個(gè)epoch訓(xùn)練完以后執(zhí)行Valid
- Test Loop(test_step) 在整個(gè)訓(xùn)練完成以后執(zhí)行Test
- Optimizer(configure_optimizers) 配置優(yōu)化器等
展示一個(gè)最簡(jiǎn)代碼:
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
... def __init__(self):
... super().__init__()
... self.l1 = torch.nn.Linear(28 * 28, 10)
...
... def forward(self, x):
... return torch.relu(self.l1(x.view(x.size(0), -1)))
...
... def training_step(self, batch, batch_idx):
... x, y = batch
... y_hat = self(x)
... loss = F.cross_entropy(y_hat, y)
... return loss
...
... def configure_optimizers(self):
... return torch.optim.Adam(self.parameters(), lr=0.02)
那么整個(gè)生命周期流程是如何組織的?
4.1 準(zhǔn)備工作
這部分包括LightningModule的初始化、準(zhǔn)備數(shù)據(jù)、配置優(yōu)化器。每次只執(zhí)行一次,相當(dāng)于構(gòu)造函數(shù)的作用。
__init__()
(初始化 LightningModule )prepare_data()
(準(zhǔn)備數(shù)據(jù),包括下載數(shù)據(jù)、預(yù)處理等等)configure_optimizers()
(配置優(yōu)化器)
4.2 測(cè)試 驗(yàn)證部分
實(shí)際運(yùn)行代碼前,會(huì)隨即初始化模型,然后運(yùn)行一次驗(yàn)證代碼,這樣可以防止在你訓(xùn)練了幾個(gè)epoch之后要進(jìn)行Valid的時(shí)候發(fā)現(xiàn)驗(yàn)證部分出錯(cuò)。主要測(cè)試下面幾個(gè)函數(shù):
4.3 加載數(shù)據(jù)
調(diào)用以下方法進(jìn)行加載數(shù)據(jù)。
4.4 訓(xùn)練
每個(gè)batch的訓(xùn)練被稱為一個(gè)step,故先運(yùn)行train_step函數(shù)。
當(dāng)經(jīng)過多個(gè)batch, 默認(rèn)49個(gè)step的訓(xùn)練后,會(huì)進(jìn)行驗(yàn)證,運(yùn)行validation_step函數(shù)。
當(dāng)完成一個(gè)epoch的訓(xùn)練以后,會(huì)對(duì)整個(gè)epoch結(jié)果進(jìn)行驗(yàn)證,運(yùn)行validation_epoch_end函數(shù)
(option)如果需要的話,可以調(diào)用測(cè)試部分代碼:
以MNIST為例,將PyTorch版本代碼轉(zhuǎn)為PyTorch Lightning。
5.1 PyTorch版本訓(xùn)練MNIST
對(duì)于一個(gè)PyTorch的代碼來(lái)說,一般是這樣構(gòu)建網(wǎng)絡(luò)(源碼來(lái)自PyTorch中的example庫(kù))。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
還有兩個(gè)主要工作是構(gòu)建訓(xùn)練函數(shù)和測(cè)試函數(shù)。
在訓(xùn)練函數(shù)中需要完成:
- 數(shù)據(jù)獲取
data, target = data.to(device), target.to(device)
- 清空優(yōu)化器梯度
optimizer.zero_grad()
- 前向傳播
output = model(data)
- 計(jì)算損失函數(shù)
loss = F.nll_loss(output, target)
- 優(yōu)化器進(jìn)行單次優(yōu)化
optimizer.step()
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run:
break
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
其他部分比如數(shù)據(jù)加載、數(shù)據(jù)增廣、優(yōu)化器、訓(xùn)練流程都是在main中執(zhí)行的,采用的是一種面向過程的方法。
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device('cuda' if use_cuda else 'cpu')
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
if args.save_model:
torch.save(model.state_dict(), 'mnist_cnn.pt')
5.2 Lightning版本訓(xùn)練MNIST
第一部分,也就是歸為研究代碼,主要是模型的結(jié)構(gòu)、訓(xùn)練等部分。被抽象為L(zhǎng)ightningModule類。
class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('valid_loss', loss)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('test_loss', loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parser
可以看出,和PyTorch版本最大的不同之處在于多了幾個(gè)流程處理函數(shù):
- training_step,相當(dāng)于訓(xùn)練過程中處理一個(gè)batch的內(nèi)容
- validation_step,相當(dāng)于驗(yàn)證過程中處理一個(gè)batch的內(nèi)容
- configure_optimizers, 這部分用于處理optimizer和scheduler
- add_module_specific_args代表這部分控制的是與模型相關(guān)的參數(shù)
除此以外,main函數(shù)主要有以下幾個(gè)部分:
def cli_main():
pl.seed_everything(1234) # 這個(gè)是用于固定seed用
# args
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
parser = MNISTDataModule.add_argparse_args(parser)
args = parser.parse_args()
# data
dm = MNISTDataModule.from_argparse_args(args)
# model
model = LitClassifier(args.hidden_dim, args.learning_rate)
# training
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, datamodule=dm)
result = trainer.test(model, datamodule=dm)
pprint(result)
可以看出Lightning版本的代碼代碼量略低于PyTorch版本,但是同時(shí)將一些細(xì)節(jié)忽略了,比如訓(xùn)練的具體流程直接使用fit搞定,這樣不會(huì)出現(xiàn)忘記清空optimizer等低級(jí)錯(cuò)誤。
總體來(lái)說,PyTorch Lightning是一個(gè)發(fā)展迅速的框架,如同fastai、keras、ignite等二次封裝的框架一樣,雖然易用性得到了提升,讓用戶可以通過更短的代碼完成任務(wù),但是遇到錯(cuò)誤的時(shí)候,往往就需要查看API甚至涉及框架源碼才能夠解決。前者降低門檻,后者略微提升了門檻。
筆者使用這個(gè)框架大概一周了,從使用者角度來(lái)談?wù)剝?yōu)缺點(diǎn):
6.1 優(yōu)點(diǎn)
- 簡(jiǎn)化了部分代碼,之前如果要轉(zhuǎn)到GPU上,需要用to(device)方法判斷,然后轉(zhuǎn)過去。有了PyTorch lightning的幫助,可以自動(dòng)幫你處理,通過設(shè)置trainer中的gpus參數(shù)即可。
- 提供了一些有用的工具,比如混合精度訓(xùn)練、分布式訓(xùn)練、Horovod
- API比較完善,大部分都有例子,少部分講的不夠詳細(xì)。
- 社區(qū)還是比較活躍的,如果有問題,可以在issue中提問。
- 實(shí)驗(yàn)結(jié)果整理的比較好,將每次實(shí)驗(yàn)劃分為version 0-n,同時(shí)可以用tensorboard比較多個(gè)實(shí)驗(yàn),非常友好。
6.2 缺點(diǎn)
- 引入了一些新的概念,進(jìn)一步加大了使用者的學(xué)習(xí)成本,比如pl_bolts
- 很多原本習(xí)慣于在Pytorch中使用的功能,在PyTorch Lightning中必須查API才能使用,比如我想用scheduler,就需要去查API,然后發(fā)現(xiàn)在configure_optimizers函數(shù)中實(shí)現(xiàn),然后模仿demo實(shí)現(xiàn),因此也帶來(lái)了一定的門檻。
- 有些報(bào)錯(cuò)比較迷,筆者曾遇到過執(zhí)行的時(shí)候發(fā)現(xiàn)多線程出問題,比較難以排查,最后通過更改distributed_backend得到了解決。遇到新的坑要去API里找答案,如果沒有解決繼續(xù)去Issue里找答案。
【1】 https://zhuanlan.zhihu.com/p/120331610
【2】https://pytorch-lightning./en/latest/introduction_guide.html
【3】https://github.com/pytorch/examples/blob/master/mnist/main.py
【4】 https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py