web-dev-qa-db-ja.com

画像チャネルPyTorch全体の平均と標準偏差を見つける

たとえば、寸法(B x C x W x H)のテンソル形式の画像のバッチがあるとします。Bはバッチサイズ、Cは画像のチャネル数、WとHは幅と高さです。それぞれ画像。 transforms.Normalize()関数を使用して、C画像チャネル全体のデータセットの平均と標準偏差に関して画像を正規化しようとしています、1 x Cの形式のテンソルが必要なことを意味します。これを行う簡単な方法はありますか?

私はtorch.view(C, -1).mean(1)torch.view(C, -1).std(1)を試しましたが、エラーが発生しました:

_view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
_

編集

PyTorchでview()がどのように機能するかを調べた後、私のアプローチが機能しない理由を理解しました。ただし、チャネルごとの平均と標準偏差を取得する方法はまだわかりません。

6
ch1maera

標準偏差ではなく、分散が追加されることに注意してください。詳細な説明はこちら: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters

変更されたコードは次のとおりです。

nimages = 0
mean = 0.0
var = 0.0
for i_batch, batch_target in enumerate(trainloader):
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    var += batch.var(2).sum(0)

mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
3
debadeepta