web-dev-qa-db-ja.com

sklearnのStratifiedKFoldとStratifiedShuffleSplitの違い

タイトルから、私はの違いは何ですか

StratifiedKFold パラメーター付きshuffle = True

StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

そして

StratifiedShuffleSplit

StratifiedShuffleSplit(n_splits=10, test_size=’default’, train_size=None, random_state=0)

stratifiedShuffleSplitを使用する利点は何ですか

32
gabboshow

KFoldsでは、シャッフルを使用しても、各テストセットが重複しないようにする必要があります。 KFoldsとシャッフルでは、データは開始時に一度シャッフルされ、その後、希望する分割数に分割されます。テストデータは常に分割の1つで、残りはトレインデータです。

ShuffleSplitでは、データは毎回シャッフルされ、その後分割されます。これは、テストセットが分割間で重複する可能性があることを意味します。

違いの例については、このブロックを参照してください。 ShuffleSplitのテストセットの要素の重複に注意してください。

splits = 5

tx = range(10)
ty = [0] * 5 + [1] * 5

from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
from sklearn import datasets

kfold = StratifiedKFold(n_splits=splits, shuffle=True, random_state=42)
shufflesplit = StratifiedShuffleSplit(n_splits=splits, random_state=42, test_size=2)

print("KFold")
for train_index, test_index in kfold.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

print("Shuffle Split")
for train_index, test_index in shufflesplit.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

出力:

KFold
TRAIN: [0 2 3 4 5 6 7 9] TEST: [1 8]
TRAIN: [0 1 2 3 5 7 8 9] TEST: [4 6]
TRAIN: [0 1 3 4 5 6 8 9] TEST: [2 7]
TRAIN: [1 2 3 4 6 7 8 9] TEST: [0 5]
TRAIN: [0 1 2 4 5 6 7 8] TEST: [3 9]
Shuffle Split
TRAIN: [8 4 1 0 6 5 7 2] TEST: [3 9]
TRAIN: [7 0 3 9 4 5 1 6] TEST: [8 2]
TRAIN: [1 2 5 6 4 8 9 0] TEST: [3 7]
TRAIN: [4 6 7 8 3 5 1 2] TEST: [9 0]
TRAIN: [7 2 6 5 4 3 0 9] TEST: [1 8]

それらをいつ使用するかについては、クロス検証にKFoldsを使用する傾向があり、電車/テストセットの分割には2分割でShuffleSplitを使用します。しかし、私は両方のための他のユースケースがあると確信しています。

43
Ken Syme

@Ken Symeにはすでに非常に良い答えがあります。何かを追加したいだけです。

  • StratifiedKFoldKFoldのバリエーションです。最初に、StratifiedKFoldはデータをシャッフルし、その後、データをn_splits部分に分割して完了します。これで、各パーツがテストセットとして使用されます。 のみで、常にデータを1回シャッフルすることに注意してください。

shuffle = Trueを使用すると、データはrandom_stateによってシャッフルされます。それ以外の場合、データはnp.random(デフォルト)によってシャッフルされます。たとえば、n_splits = 4の場合、データにはy(従属変数)の3つのクラス(ラベル)があります。 4つのテストセットは、重複することなくすべてのデータをカバーします。

enter image description here

  • 一方、StratifiedShuffleSplitShuffleSplitのバリエーションです。最初に、StratifiedShuffleSplitはデータをシャッフルし、次にデータをn_splits部分に分割します。ただし、まだ完了していません。このステップの後、StratifiedShuffleSplitはテストセットとして使用するパーツを1つ選択します。次に、同じプロセスをn_splits - 1他の回数繰り返して、n_splits - 1他のテストセットを取得します。同じデータで下の図を見てください。ただし、今回は、4つのテストセットがすべてのデータをカバーしているわけではありません。つまり、テストセット間で重複があります。

enter image description here

そのため、ここでの違いは、StratifiedKFoldが一度シャッフルして分割するだけなので、テストセットはオーバーラップしません、一方StratifiedShuffleSplit分割する前に毎回シャッフルし、n_splits回分割します。テストセットはオーバーラップできます

  • :2つの方法は「層化された折り畳み」を使用します(「層化された」が両方の名前に表示される理由)。つまり、各パーツは、元のデータと同じ割合の各クラス(ラベル)のサンプルの割合を保持します。詳しくは cross_validation documents をご覧ください。
26
Catbuilts