web-dev-qa-db-ja.com

PyTorch:カスタムデータセットにDataLoaderを使用する方法

torch.utils.data.Datasettorch.utils.data.DataLoaderを(torchvision.datasetsだけでなく)自分のデータで使用する方法は?

DataLoadersで使用する組み込みのTorchVisionDatasetsを使用して、データセットで使用する方法はありますか?

45
Sarthak

はい、可能です。自分でオブジェクトを作成するだけです。

import torch.utils.data as data_utils

train = data_utils.TensorDataset(features, targets)
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)

ここで、featuresおよびtargetsはテンソルです。 featuresは2次元、つまり各行が1つのトレーニングサンプルを表す行列である必要があり、targetsはスカラーを予測しようとしているか、ベクトル。

お役に立てば幸いです!


EDIT:@sarthakの質問への応答

基本的にははい。タイプTensorDataのオブジェクトを作成すると、コンストラクターは、フィーチャテンソル(実際にはdata_tensorと呼ばれる)とターゲットテンソル(target_tensorと呼ばれる)の最初の次元に同じ長さ:

assert data_tensor.size(0) == target_tensor.size(0)

ただし、これらのデータを後でニューラルネットワークに送りたい場合は、注意する必要があります。畳み込み層はデータと同じように機能しますが、他のタイプの層はすべて、データが行列形式で提供されることを期待しています。したがって、このような問題に遭遇した場合、簡単な解決策は、メソッドFloatTensorを使用して4Dデータセット(何らかのテンソル、たとえばview)を行列に変換することです。 。 5000xnxnx3データセットの場合、これは次のようになります。

2d_dataset = 4d_dataset.view(5000, -1)

(値-1は、2番目の次元の長さを自動的に計算するようにPyTorchに指示します。)

46
pho7

これは、data.Datasetクラスを拡張することで簡単に行えます。 API によると、2つの関数__getitem____len__を実装するだけです。

その後、APIと@ pho7の回答に示されているように、DataLoaderでデータセットをラップできます。

ImageFolderクラスは参照だと思います。コード here を参照してください。

9
user3693922