Converti la colonna Spark DataFrame in un elenco Python


104

Lavoro su un dataframe con due colonne, mvv e count.

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

vorrei ottenere due elenchi contenenti valori mvv e valore di conteggio. Qualcosa di simile a

mvv = [1,2,3,4]
count = [5,9,3,1]

Quindi, ho provato il seguente codice: La prima riga dovrebbe restituire un elenco di righe in Python. Volevo vedere il primo valore:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

Ma ricevo un messaggio di errore con la seconda riga:

AttributeError: getInt


Come di Spark 2.3, questo codice è il più veloce e meno probabilità di causare eccezioni OutOfMemory: list(df.select('mvv').toPandas()['mvv']). Arrow è stato integrato in PySpark che ha accelerato toPandasnotevolmente. Non utilizzare gli altri approcci se utilizzi Spark 2.3+. Vedi la mia risposta per ulteriori dettagli sul benchmarking.
Poteri

Risposte:


141

Vedi, perché in questo modo che stai facendo non funziona. Innanzitutto, stai cercando di ottenere un numero intero da un tipo di riga , l'output della tua raccolta è come questo:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

Se prendi qualcosa del genere:

>>> firstvalue = mvv_list[0].mvv
Out: 1

Otterrai il mvvvalore. Se vuoi tutte le informazioni dell'array puoi prendere qualcosa del genere:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

Ma se provi lo stesso per l'altra colonna, ottieni:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

Questo accade perché countè un metodo integrato. E la colonna ha lo stesso nome di count. Una soluzione alternativa per farlo è modificare il nome della colonna countin _count:

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

Ma questa soluzione alternativa non è necessaria, poiché puoi accedere alla colonna utilizzando la sintassi del dizionario:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

E finalmente funzionerà!


funziona alla grande per la prima colonna, ma non funziona per il conteggio delle colonne penso a causa di (la funzione count of spark)
a.moussa

Puoi aggiungere cosa stai facendo con il conteggio? Aggiungi qui nei commenti.
Thiago Baldim

grazie per la tua risposta Quindi questa riga funziona mvv_list = [int (i.mvv) for i in mvv_count.select ('mvv'). collect ()] ma non questo count_list = [int (i.count) for i in mvv_count .select ('count'). collect ()] restituisce sintassi non valida
a.moussa

Non è necessario aggiungere questo select('count')uso in questo modo: count_list = [int(i.count) for i in mvv_list.collect()]aggiungerò l'esempio alla risposta.
Thiago Baldim

1
@ a.moussa [i.['count'] for i in mvv_list.collect()]lavora per rendere esplicito l'uso della colonna denominata 'count' e non la countfunzione
user989762

103

Seguendo una riga si ottiene l'elenco desiderato.

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

3
Dal punto di vista delle prestazioni questa soluzione è molto più veloce della tua soluzione mvv_list = [int (i.mvv) for i in mvv_count.select ('mvv'). Collect ()]
Chanaka Fernando

Questa è di gran lunga la migliore soluzione che ho visto. Grazie.
hui chen

22

Questo ti darà tutti gli elementi come un elenco.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

1
Questa è la soluzione più veloce ed efficiente per Spark 2.3+. Vedi i risultati del benchmarking nella mia risposta.
Poteri

16

Il codice seguente ti aiuterà

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

3
Questa dovrebbe essere la risposta accettata. il motivo è che rimani in un contesto scintilla durante tutto il processo e poi raccogli alla fine invece di uscire prima dal contesto scintilla, il che potrebbe causare una raccolta più ampia a seconda di ciò che stai facendo.
AntiPawn79

15

Sui miei dati ho ottenuto questi benchmark:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0,52 sec

>>> [row[col] for row in data.collect()]

0,271 sec

>>> list(data.select(col).toPandas()[col])

0.427 sec

Il risultato è lo stesso


1
Se lo usi al toLocalIteratorposto di collectesso dovrebbe anche essere più efficiente la memoria[row[col] for row in data.toLocalIterator()]
oglop

6

Se ricevi il seguente errore:

AttributeError: l'oggetto "list" non ha l'attributo "collect"

Questo codice risolverà i tuoi problemi:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]

Ho ricevuto anche quell'errore e questa soluzione ha risolto il problema. Ma perché ho ricevuto l'errore? (Molti altri sembrano non
capirlo

2

Ho eseguito un'analisi di benchmarking e list(mvv_count_df.select('mvv').toPandas()['mvv']) è il metodo più veloce. Sono molto sorpreso.

Ho eseguito i diversi approcci su 100 mila / 100 milioni di set di dati di righe utilizzando un cluster i3.xlarge a 5 nodi (ogni nodo ha 30,5 GB di RAM e 4 core) con Spark 2.4.5. I dati sono stati distribuiti uniformemente su 20 file Parquet compressi scattanti con una singola colonna.

Ecco i risultati del benchmarking (tempi di esecuzione in secondi):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Regole d'oro da seguire quando si raccolgono dati sul nodo driver:

  • Prova a risolvere il problema con altri approcci. La raccolta dei dati nel nodo driver è costosa, non sfrutta la potenza del cluster Spark e dovrebbe essere evitata quando possibile.
  • Raccogli il minor numero di righe possibile. Aggrega, deduplica, filtra e elimina le colonne prima di raccogliere i dati. Invia il minor numero di dati possibile al nodo del driver.

toPandas è stato notevolmente migliorato in Spark 2.3 . Probabilmente non è l'approccio migliore se stai utilizzando una versione Spark precedente alla 2.3.

Vedi qui per maggiori dettagli / risultati di benchmarking.


2

Una possibile soluzione sta utilizzando la collect_list()funzione da pyspark.sql.functions. Questo aggregherà tutti i valori delle colonne in un array pyspark che viene convertito in un elenco python quando raccolto:

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.