web-dev-qa-db-ja.com

Spark DataFrame groupbyを使用するときに他の列を取得する方法は?

dataFrame groupbyを次のように使用すると:

df.groupBy(df("age")).agg(Map("id"->"count"))

「age」列と「count(id)」列のあるDataFrameのみを取得しますが、dfには「name」のような他の列が多数あります。

全体として、MySQLのように結果を取得したいのですが、

「年齢別にdfグループからname、age、count(id)を選択」

Sparkでgroupbyを使用する場合はどうすればよいですか?

27
Psychevic

簡単に言えば、集計結果を元のテーブルに結合する必要があります。 Spark SQLは、集計クエリで追加の列を許可しないほとんどの主要なデータベース(PostgreSQL、Oracle、MS SQL Server)と同じpre-SQL:1999の規則に従います。

カウント結果などの集計は十分に定義されておらず、このタイプのクエリをサポートするシステムでは動作が異なる傾向があるため、firstlastなどの任意の集計を使用して追加の列を含めることができます。

場合によっては、aggを使用してselectをウィンドウ関数とそれに続くwhereに置き換えることができますが、コンテキストによっては非常に高価になる場合があります。

29
zero323

GroupByを実行した後にすべての列を取得する1つの方法は、結合関数を使用することです。

feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)

data_joinedには、カウント値を含むすべての列が含まれるようになりました。

8
Swetha Kannan

この解決策が役立つかもしれません。

from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window

    name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
                 (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
                 (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]

age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, ['id', 'name', 'age'])

name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()

出力:

+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26|    1|
|113| tuv| 29|    1|
|110| tuv| 27|    1|
|106| klm| 19|    1|
|103| efg| 22|    1|
|104| ghi| 21|    1|
|105| ijk| 20|    1|
|112| rst| 28|    1|
|101| abc| 24|    2|
|102| cde| 24|    2|
|107| mno| 18|    3|
|111| pqr| 18|    3|
|108| pqr| 18|    3|
+---+----+---+-----+
1

ここで私が出会った例 spark-workshop

_val populationDF = spark.read
                .option("infer-schema", "true")
                .option("header", "true")
                .format("csv").load("file:///databricks/driver/population.csv")
                .select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))
_

val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))

他の列を取得するには、元のDFと集約された列

_populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()
_
0
Mohamed Hosni

集計関数は、グループ内の指定された列の行の値を減らします。他の行の値を保持する場合は、各値の元になる行を指定するリダクションロジックを実装する必要があります。たとえば、年齢の最大値を持つ最初の行のすべての値を保持します。このために、UDAF(ユーザー定義の集計関数)を使用して、グループ内の行を減らすことができます。

import org.Apache.spark.sql._
import org.Apache.spark.sql.functions._


object AggregateKeepingRowJob {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1L, "Moe",  "Slap",  2.0, 18),
      (2L, "Larry",  "Spank",  3.0, 15),
      (3L, "Curly",  "Twist", 5.0, 15),
      (4L, "Laurel", "Whimper", 3.0, 15),
      (5L, "Hardy", "Laugh", 6.0, 15),
      (6L, "Charley",  "Ignore",   5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    rawDf.show(false)
    rawDf.printSchema

    val maxAgeUdaf = new KeepRowWithMaxAge

    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count("id"),
        max(col("money")),
        maxAgeUdaf(
          col("id"),
          col("name"),
          col("requisite"),
          col("money"),
          col("age")).as("KeepRowWithMaxAge")
      )

    aggDf.printSchema
    aggDf.show(false)

  }


}

UDAF:

import org.Apache.spark.sql.Row
import org.Apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.Apache.spark.sql.types._

class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
// This is the input fields for your aggregate function.
override def inputSchema: org.Apache.spark.sql.types.StructType =
  StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil
  )

// This is the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
  StructField("store", StringType) ::
  StructField("prod", StringType) ::
  StructField("amt", DoubleType) ::
  StructField("units", IntegerType) :: Nil
)


// This is the output type of your aggregation function.
override def dataType: DataType =
  StructType((Array(
    StructField("store", StringType),
    StructField("prod", StringType),
    StructField("amt", DoubleType),
    StructField("units", IntegerType)
  )))

override def deterministic: Boolean = true

// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
  buffer(0) = ""
  buffer(1) = ""
  buffer(2) = 0.0
  buffer(3) = 0
}

// This is how to update your buffer schema given an input.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

  val amt = buffer.getAs[Double](2)
  val candidateAmt = input.getAs[Double](2)

  amt match {
    case a if a < candidateAmt =>
      buffer(0) = input.getAs[String](0)
      buffer(1) = input.getAs[String](1)
      buffer(2) = input.getAs[Double](2)
      buffer(3) = input.getAs[Int](3)
    case _ =>
  }
}

// This is how to merge two objects with the bufferSchema type.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

  buffer1(0) = buffer2.getAs[String](0)
  buffer1(1) = buffer2.getAs[String](1)
  buffer1(2) = buffer2.getAs[Double](2)
  buffer1(3) = buffer2.getAs[Int](3)
}

// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
  buffer
}
}
0
Rubber Duck