web-dev-qa-db-ja.com

警告:tensorflow:sample_weightモードが...から['...']に強制変換されました

.fit_generator()または.fit()を使用して画像分類子をトレーニングし、辞書をclass_weight=に引数として渡します。

TF1.xでエラーが発生したことはありませんが、2.1ではトレーニングを開始すると次の出力が表示されます。

WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']

...から['...']に強制変換するとはどういう意味ですか?

tensorflowのリポジトリに関するこの警告のソースは here であり、コメントは次のとおりです。

Sample_weight_modesをターゲット構造に強制変換しようとしました。これは、Modelが内部表現の出力を平坦化するという事実に暗黙的に依存しています。

44
jorijnsmit

これは偽のメッセージのようです。 TensorFlow 2.1にアップグレードした後も同じ警告メッセージが表示されますが、クラスの重みまたはサンプルの重みをまったく使用していません。私はこのようなタプルを返すジェネレータを使用します:

return inputs, targets

そして今、私はそれを次のように変更して警告を消しました:

return inputs, targets, [None]

これが関連するかどうかはわかりませんが、私のモデルでは3つの入力を使用しているため、inputs変数は実際には3つのnumpy配列のリストです。 targetsは単一のnumpy配列です。

いずれにせよ、それは単なる警告です。トレーニングはどちらの方法でも問題なく機能します。

TensorFlow 2.2の編集:

このバグはTensorFlow 2.2で修正されたようで、すばらしいです。ただし、上記の修正はTF 2.2では失敗します。これは、サンプルの重みの形状を取得しようとするためで、明らかにAttributeError: 'NoneType' object has no attribute 'shape'で失敗します。そのため、2.2にアップグレードするときに上記の修正を元に戻します。

6
jlh

これはテンソルフローのバグであり、デフォルトのパラメータsample_weight_mode=Noneを使用してmodel.compile()を呼び出し、次にsample_weightまたはclass_weightを指定してmodel.fit()を呼び出すと発生すると考えられます。

Tensorflowリポジトリから:

  • fit()は最終的に_process_training_inputs()を呼び出します
  • _process_training_inputs()setssample_weight_modes = [None]に基づいてmodel.sample_weight_mode = Noneを作成し、sample_weight_modes = [None]DataAdapterを作成します
  • DataAdapterは、 初期化 中にsample_weight_modes = [None]を使用してbroadcast_sample_weight_modes()を呼び出します
  • broadcast_sample_weight_modes()期待しているようですsample_weight_modes = Noneを受け取りますが、[None]を受け取ります
  • [None]sample_weight/class_weightとは異なる構造であると断言し、sample_weight/class_weightの構造に適合させることによってNoneに上書きし、警告を出力します

ただし、DataAdaptersample_weight_modesNoneに設定されているため、これはfit()には影響しません。

Tensorflow documentation は、sample_weightはnumpy-arrayでなければならないことを述べていることに注意してください。代わりにfit()sample_weight.tolist()で呼び出すと警告は表示されませんが、_process_numpy_inputs()が呼び出されたときにsample_weightNoneに暗黙的に上書きされます preprocessing し、1より大きい長さの入力を受け取ります。

10
Max

私はあなたの要旨を取り、TFAの代わりにTensorflow 2.0をインストールしましたが、そのような警告なしで機能しました。

これが完全なコードの Gist です。 Tensorflowをインストールするためのコードを以下に示します。

!pip install tensorflow==2.0

成功した実行のスクリーンショットを以下に示します。

enter image description here

更新:このバグはTensorflow Version 2.2.で修正されています

4