web-dev-qa-db-ja.com

グループ化されたSpark dataframeでブール値をカウントする方法

グループ化されたSparkデータフレームの列でtrueのレコードの数を数えたいのですが、Pythonでそれを行う方法がわかりません。たとえば、regionのデータがあります。 、salaryおよびIsUnemployed列(ブール値としてIsUnemployed)。各地域の失業者の数を確認したい。filter、次にgroupbyを実行できることはわかっているが、以下のように2つの集計を同時に生成したい。

from pyspark.sql import functions as F  
data.groupby("Region").agg(F.avg("Salary"), F.count("IsUnemployed")) 
14
MYjx

おそらく最も単純な解決策は、CASTTRUE-> 1、FALSE-> 0のCスタイル)とSUMです。

(data
    .groupby("Region")
    .agg(F.avg("Salary"), F.sum(F.col("IsUnemployed").cast("long"))))

もう少し普遍的で慣用的な解決策はCASE WHENCOUNT

(data
    .groupby("Region")
    .agg(
        F.avg("Salary"),
        F.count(F.when(F.col("IsUnemployed"), F.col("IsUnemployed")))))

しかし、ここでは明らかにやり過ぎです。

23
zero323