web-dev-qa-db-ja.com

メソッド「train_test_split」からのパラメーター「stratify」(scikit Learn)

パッケージscikit Learnのtrain_test_splitを使用しようとしていますが、パラメータstratifyに問題があります。コードは次のとおりです。

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

ただし、次の問題が引き続き発生します。

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

誰かが何が起こっているか考えていますか?以下は関数のドキュメントです。

[...]

stratify:array-likeまたはNone(デフォルトはNone)

Noneでない場合、データは層状に分割され、これをラベル配列として使用します。

バージョン0.17の新機能:stratifysplitting

[...]

58
Daneel Olivaw

Scikit-Learnは、「stratify」という引数を認識していないことを伝えているだけで、誤って使用しているわけではありません。これは、引用したドキュメントに示されているように、パラメーターがバージョン0.17で追加されたためです。

したがって、Scikit-Learnを更新するだけです。

44
Borja

このstratifyパラメーターは、生成されるサンプルの値の割合がパラメーターstratifyに提供される値の割合と同じになるように分割します。

たとえば、変数yが値0および1を持つバイナリカテゴリ変数であり、ゼロの25%と1の75%がある場合、stratify=yはランダム分割には、25%の0と75%の1があります。

189
Fazzolini

Google経由でここに来る私の将来の自己のために:

train_test_splitmodel_selectionにあるため、次のとおりです。

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

それを使用する方法です。 random_stateの設定は、再現性のために望ましいです。

36
Martin Thoma

このコンテキストでは、階層化とは、train_test_splitメソッドが、入力データセットと同じ割合のクラスラベルを持つトレーニングおよびテストサブセットを返すことを意味します。

8
X. Wang

このコードを実行してみてください。「うまくいく」だけです。

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])
3