web-dev-qa-db-ja.com

事前に学習したニューラルネットワークをグレースケール画像で使用するにはどうすればよいですか?

グレースケール画像を含むデータセットがあり、それらの最新のCNNをトレーニングしたいと思います。事前に訓練されたモデル( here など)を微調整したいと思います。

問題は、重みを見つけることができるほとんどすべてのモデルが、RGB画像を含むImageNetデータセットでトレーニングされていることです。

入力レイヤーは形状のバッチを想定しているため、これらのモデルのいずれも使用できません(batch_size, height, width, 3)または(64, 224, 224, 3)私の場合、しかし私の画像バッチは(64, 224, 224)

これらのモデルのいずれかを使用できる方法はありますか?重みを読み込んだ後、入力レイヤーをドロップし、独自のレイヤーを追加することを考えました(トップレイヤーの場合と同様)。このアプローチは正しいですか?

10
Jcart

モデルのアーキテクチャは変更できません。重みは特定の入力構成用にトレーニングされているため、変更できません。最初のレイヤーを自分のレイヤーに置き換えると、残りのウェイトはほとんど役に立たなくなります。

-編集:プルーンが提案した詳細-
CNNは、深くなるにつれて、以前のレイヤーが抽出した下位レベルの特徴から派生した上位レベルの特徴を抽出できるように構築されています。 CNNの最初のレイヤーを削除することにより、後続のレイヤーは入力として想定されているフィーチャを受け取らないため、フィーチャの階層を破壊します。あなたの場合、2番目のレイヤーは、最初のレイヤーの特徴expectにトレーニングされています。最初のレイヤーをランダムな重みで置き換えることにより、再トレーニングが必要になるため、後続のレイヤーで行われたトレーニングは基本的に破棄されます。私は彼らが最初の訓練で学んだ知識を保持できるとは思わない。
---編集の終了---

ただし、簡単な方法があります。これにより、モデルをグレースケール画像で動作させることができます。画像をappearにしてRGBにするだけです。これを行う最も簡単な方法は、新しい次元で画像配列を3回repeatすることです。 3つのチャネルすべてで同じ画像を使用するため、モデルのパフォーマンスはRGB画像と同じである必要があります。

numpyでは、これは次のように簡単に実行できます。

print(grayscale_batch.shape)  # (64, 224, 224)
rgb_batch = np.repeat(grayscale_batch[..., np.newaxis], 3, -1)
print(rgb_batch.shape)  # (64, 224, 224, 3)

これが機能する方法は、最初に(チャネルを配置するために)新しい次元を作成し、次にこの新しい次元で既存の配列を3回繰り返します。

また、keras ' ImageDataGenerator はグレースケール画像をRGBとして読み込むことができると確信しています。

15
Djib2011

現在受け入れられている答えに従ってグレースケール画像をRGBに変換することは、この問題への1つのアプローチですが、最も効率的ではありません。モデルの最初の畳み込み層の重みを変更して、指定された目標を達成できます。変更されたモデルは、そのまま使用でき(精度は低下します)、微調整可能です。最初のレイヤーのウェイトを変更しても、他のウェイトが残りのウェイトを無効にすることはありません。

これを行うには、事前トレーニング済みの重みがロードされるコードを追加する必要があります。選択したフレームワークで、1チャンネルモデルに割り当てる前に、ネットワーク内の最初の畳み込み層の重みを取得して変更する方法を理解する必要があります。必要な変更は、入力チャネルの次元で重みテンソルを合計することです。重みテンソルの編成方法は、フレームワークごとに異なります。 PyTorchのデフォルトは[out_channels、in_channels、kernel_height、kernel_width]です。 Tensorflowでは、[kernel_height、kernel_width、in_channels、out_channels]であると信じています。

例としてPyTorchを使用して、TorchvisionのResNet50モデル( https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py )、ウェイトの形状conv1は[64、3、7、7]です。次元1で合計すると、形状のテンソルが得られます[64、1、7、7]。一番下には、TorchvisionのResNetモデルで動作するコードのスニペットが含まれています。このコードは、モデルの入力チャネルの異なる数を指定する引数(inchans)が追加されていることを前提としています。

この動作を証明するために、事前にトレーニングした重みを使用してResNet50でImageNet検証を3回実行しました。実行2と3の数値にはわずかな違いがありますが、最小限であり、微調整したら関係ないはずです。

  1. 未修正ResNet50 w/RGB画像:Prec @ 1:75.6、Prec @ 5:92.8
  2. 未修正ResNet50 w/3-chanグレースケールイメージ:Prec @ 1:64.6、Prec @ 5:86.4
  3. 変更された1チャンResNet50 w/1チャングレースケールイメージ:Prec @ 1:63.8、Prec @ 5:86.1
def _load_pretrained(model, url, inchans=3):
    state_dict = model_Zoo.load_url(url)
    if inchans == 1:
        conv1_weight = state_dict['conv1.weight']
        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    Elif inchans != 3:
        assert False, "Invalid number of inchans for pretrained weights"
    model.load_state_dict(state_dict)

def resnet50(pretrained=False, inchans=3):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], inchans=inchans)
    if pretrained:
        _load_pretrained(model, model_urls['resnet50'], inchans=inchans)
    return model
6
rwightman

グレースケール画像をRGB画像に変換してみませんか?

tf.image.grayscale_to_rgb(
    images,
    name=None
)
3
Hu Xixi