web-dev-qa-db-ja.com

なぜzero_grad()を明示的に呼び出す必要があるのですか?

PyTorchの勾配を明示的にゼロにする必要があるのはなぜですか? loss.backward()が呼び出されたときにグラデーションをゼロにできないのはなぜですか?グラフ上に勾配を維持し、ユーザーに明示的に勾配をゼロにするよう要求することにより、どのようなシナリオが提供されますか?

38
Wasi Ahmad

zero_grad()の後で(勾配が計算されるとき)、loss.backward()を使用して勾配降下を進める必要があるため、明示的にoptimizer.step()を呼び出す必要があります。より具体的には、これらの2つの操作loss.backward()optimizer.step()は分離されており、optimizer.step()は計算されたばかりの勾配を必要とするため、勾配は自動的にゼロになりません。

さらに、いくつかのバッチ間で勾配を蓄積する必要がある場合があります。そのためには、単にbackwardを複数回呼び出して、一度最適化するだけです。

40
danche

PyTorchの現在のセットアップのユースケースがあります。

すべてのステップで予測を行うリカレントニューラルネットワーク(RNN)を使用している場合、時間を遡って勾配を蓄積できるハイパーパラメーターが必要になる場合があります。時間ステップごとに勾配をゼロにしないことで、興味深く斬新な方法で逆伝播時間(BPTT)を使用できます。

BPTTまたはRNNに関する詳細情報が必要な場合は、記事 Recurrent Neural Networks Tutorial、Part 3 – Backpropagation Through Time and Vanishing Gradients または リカレントニューラルネットワークの不合理な有効性

5
twrichar

.step()を呼び出す前に勾配をそのままにしておくと、複数のバッチに勾配を蓄積したい場合に便利です(他の人が述べたように)。

また、SGDにモメンタムを実装したい場合に.step()を呼び出すためにafterに便利です。また、他のさまざまな方法が前の更新の勾配の値に依存する場合があります。

2

PyTorchにはサイクルがあります:

  • 出力または_y_hat_を入力から取得すると転送し、
  • loss = loss_fn(y_hat, y)での損失の計算
  • _loss.backward_勾配を計算するとき
  • _optimizer.step_パラメータを更新するとき

またはコード内:

_for mb in range(10): # 10 mini batches
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
_

適切なステップである_optimizer.step_の後、または次のbackward()勾配が累積する直前の勾配をクリアしない場合。累積を示す例を次に示します。

_import torch
w = torch.Rand(5)
w.requires_grad_()
print(w) 
s = w.sum() 
s.backward()
print(w.grad) # tensor([1., 1., 1., 1., 1.])
s.backward()
print(w.grad) # tensor([2., 2., 2., 2., 2.])
s.backward()
print(w.grad) # tensor([3., 3., 3., 3., 3.])
s.backward()
print(w.grad) # tensor([4., 4., 4., 4., 4.])
_

loss.backward() にはこれを指定する方法がありません。

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)

指定できるすべてのオプションから、グラデーションを手動でゼロにする方法はありません。前のミニ例のように:

_w.grad.zero_()
_

zero_grad()(明らかに以前の勾配)を使用して毎回backward()を実行し、 _preserve_grads=True_ で卒業を維持することについていくつかの議論がありましたが、これは決してありません命に。

2
prosti