web-dev-qa-db-ja.com

PyTorchのカスタム損失関数

簡単な質問が3つあります。

  1. カスタム損失関数が区別できない場合はどうなりますか?エラーをpytorchするか、何か他のことをしますか?
  2. モデルの最終的な損失を表すカスタム関数で損失変数を宣言する場合、requires_grad = Trueその変数は?またはそれは問題ではありませんか?それが問題ではない場合、なぜですか?
  3. 私は人々が時々別のレイヤーを書いてforward関数で損失を計算するのを見てきました。関数とレイヤーのどちらを書くのが望ましいアプローチですか?どうして?

私の混乱を解決するために、これらの質問に対する明確で素晴らしい説明が必要です。助けてください。

12
Wasi Ahmad

やってみよう。

  1. これは、「微分不可能」の意味によって異なります。ここで意味のある最初の定義は、PyTorchが勾配を計算する方法を知らないということです。それでも勾配を計算しようとすると、エラーが発生します。次の2つのシナリオが考えられます。

    a)グラデーションが実装されていないカスタムPyTorchオペレーションを使用している。 torch.svd()。その場合、TypeErrorを取得します。

    _import torch
    from torch.autograd import Function
    from torch.autograd import Variable
    
    A = Variable(torch.randn(10,10), requires_grad=True)
    u, s, v = torch.svd(A) # raises TypeError
    _

    b)独自の操作を実装しましたが、backward()を定義していません。この場合、NotImplementedErrorを取得します。

    _class my_function(Function): # forgot to define backward()
    
        def forward(self, x):
            return 2 * x
    
    A = Variable(torch.randn(10,10))
    B = my_function()(A)
    C = torch.sum(B)
    C.backward() # will raise NotImplementedError
    _

    意味のある2番目の定義は、「数学的に微分不可能」です。明らかに、数学的に微分不可能な演算には、backward()メソッドが実装されていないか、適切な部分勾配があってはなりません。たとえば、torch.abs()メソッドが0で部分勾配0を返すbackward()を考えます。

    _A = Variable(torch.Tensor([-1,0,1]),requires_grad=True)
    B = torch.abs(A)
    B.backward(torch.Tensor([1,1,1]))
    A.grad.data
    _

    これらの場合は、PyTorchのドキュメントを直接参照し、それぞれの操作のbackward()メソッドを直接調べてください。

  2. それは問題ではありません。 _requires_grad_ isを使用すると、サブグラフの勾配の不要な計算を回避できます。勾配を必要とする操作への単一の入力がある場合、その出力も勾配を必要とします。逆に、すべての入力が勾配を必要としない場合のみ、出力も勾配を必要としません。すべての変数が勾配を必要としなかったサブグラフでは、逆方向の計算は決して行われません。

    いくつかのVariables(たとえば、nn.Module()のサブクラスのパラメーター)が存在する可能性が高いため、loss変数にも自動的に勾配が必要になります。ただし、_requires_grad_の動作(上記を参照)については、グラフのリーフ変数の_requires_grad_のみを変更できることに注意してください。

  3. すべてのカスタムPyTorch損失関数は、__Loss_のサブクラスである_nn.Module_のサブクラスです。 ここを参照してください。 この規則を守りたい場合は、カスタム損失関数を定義するときに__Loss_をサブクラス化する必要があります。一貫性とは別に、1つの利点は、ターゲット変数をAssertionErrorまたは_requires_grad = False_としてマークしていない場合、サブクラスがvolatileを発生させることです。別の利点は、損失関数をnn.Sequential()にネストできることです。これは、_nn.Module_であるためです。これらの理由から、このアプローチをお勧めします。

11
mexmex