Come selezionare la prima riga di ciascun gruppo?


143

Ho un DataFrame generato come segue:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

I risultati sembrano:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

Come puoi vedere, DataFrame è ordinato Hourin ordine crescente, quindi perTotalValue in ordine decrescente.

Vorrei selezionare la riga superiore di ciascun gruppo, ad es

  • dal gruppo di Ora == 0 seleziona (0, cat26,30,9)
  • dal gruppo di Ora == 1 seleziona (1, cat67,28,5)
  • dal gruppo di Ora == 2 seleziona (2, cat56,39,6)
  • e così via

Quindi l'output desiderato sarebbe:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

Potrebbe essere utile poter selezionare anche le prime N righe di ciascun gruppo.

Qualsiasi aiuto è molto apprezzato.

Risposte:


231

Funzioni della finestra :

Qualcosa del genere dovrebbe fare il trucco:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Questo metodo sarà inefficace in caso di significativa inclinazione dei dati.

Aggregazione SQL semplice seguita dajoin :

In alternativa puoi unirti con un frame di dati aggregati:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Manterrà valori duplicati (se esiste più di una categoria all'ora con lo stesso valore totale). Puoi rimuoverli come segue:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Usando l'ordinazionestructs :

Trucco pulito, anche se non molto ben testato, che non richiede join o funzioni della finestra:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Con l'API DataSet (Spark 1.6+, 2.0+):

Spark 1.6 :

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 o successivo :

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

Gli ultimi due metodi possono sfruttare la combinazione lato mappa e non richiedono la riproduzione casuale completa, quindi la maggior parte delle volte dovrebbe mostrare prestazioni migliori rispetto alle funzioni della finestra e ai join. Questi bastoncini possono essere utilizzati anche con Streaming strutturato incompleted modalità di output.

Non usare :

df.orderBy(...).groupBy(...).agg(first(...), ...)

Può sembrare che funzioni (specialmente nella localmodalità) ma non è affidabile (vedi SPARK-16207 , crediti a Tzach Zohar per il collegamento del problema JIRA rilevante e SPARK-30335 ).

La stessa nota si applica a

df.orderBy(...).dropDuplicates(...)

che utilizza internamente un piano di esecuzione equivalente.


3
Sembra che dalla scintilla 1.6 sia row_number () invece di rowNumber
Adam Szałucha,

Informazioni su Non utilizzare df.orderBy (...). GropBy (...). In quali circostanze possiamo fare affidamento su orderBy (...)? o se non possiamo essere sicuri che orderBy () fornirà il risultato corretto, quali alternative abbiamo?
Ignacio Alorre,

Potrei essere trascurato qualcosa, ma in generale si consiglia di evitare groupByKey , invece si dovrebbe usare ridurreByKey. Inoltre, salverai una riga.
Thomas,

3
@Thomas che evita groupBy / groupByKey è proprio quando si ha a che fare con RDD, noterai che l'API Dataset non ha nemmeno una funzione di ridurreByKey.
soote,


16

Per Spark 2.0.2 con raggruppamento per più colonne:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

8

Questo è un esattamente lo stesso di zero323 's risposta , ma in SQL modo query.

Supponendo che il frame di dati sia stato creato e registrato come

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Funzione finestra:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Aggregazione SQL semplice seguita da join:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Utilizzando l'ordinamento su strutture:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

DataSet modo e non fanno s sono le stesse in risposta originale


2

Il modello è raggruppato per chiavi => fai qualcosa per ogni gruppo, ad es. Riduci => ritorna al frame di dati

Pensavo che l'astrazione del Dataframe fosse un po 'ingombrante in questo caso, quindi ho usato la funzionalità RDD

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)

1

La soluzione seguente esegue un solo gruppo per mezzo ed estrae le righe del tuo frame di dati che contengono il valore massimo in una sola volta. Non sono necessari ulteriori join o Windows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}

Ma mescola tutto per primo. Non è certo un miglioramento (forse non peggiore delle funzioni della finestra, a seconda dei dati).
Alper t. Turker,

hai un primo posto di gruppo, che innescherà uno shuffle. Non è peggio della funzione finestra perché in una funzione finestra valuterà la finestra per ogni singola riga nel frame di dati.
elghoto,

1

Un buon modo per farlo con l'API dataframe è usare la logica argmax in questo modo

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+

0

Qui puoi fare così -

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show

-2

Possiamo usare la funzione della finestra rank () (dove sceglieresti il ​​rank = 1) rank aggiunge semplicemente un numero per ogni riga di un gruppo (in questo caso sarebbe l'ora)

ecco un esempio. (da https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )

val dataset = spark.range(9).withColumn("bucket", 'id % 3)

import org.apache.spark.sql.expressions.Window
val byBucket = Window.partitionBy('bucket).orderBy('id)

scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
|  0|     0|   1|
|  3|     0|   2|
|  6|     0|   3|
|  1|     1|   1|
|  4|     1|   2|
|  7|     1|   3|
|  2|     2|   1|
|  5|     2|   2|
|  8|     2|   3|
+---+------+----+
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.