tracciare colori diversi per diversi livelli categoriali utilizzando matplotlib


102

Ho questo frame di dati diamondsche è composto da variabili come (carat, price, color), e voglio disegnare un grafico a dispersione di priceper caratper ogni color, il che significa diverso colorha colore diverso nella trama.

Questo è facile Rcon ggplot:

ggplot(aes(x=carat, y=price, color=color),  #by setting color=color, ggplot automatically draw in different colors
       data=diamonds) + geom_point(stat='summary', fun.y=median)

inserisci qui la descrizione dell'immagine

Mi chiedo come potrebbe essere fatto in Python usando matplotlib?

PS:

Conosco pacchetti di plottaggio ausiliari, come seaborne ggplot for python, e non li preferisco, voglio solo scoprire se è possibile fare il lavoro usando matplotlibda solo,; P


1
Sarebbe davvero bello avere qualcosa di simile integrato in matplotlib, ma sembra che non sarà facile. Discussione qui: github.com/matplotlib/matplotlib/issues/6214
nought101

Risposte:


156

Puoi passare plt.scatterun cargomento che ti permetterà di selezionare i colori. Il seguente codice definisce un colorsdizionario per mappare i colori di diamanti ai colori stampa.

import matplotlib.pyplot as plt
import pandas as pd

carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color =['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G',]

df = pd.DataFrame(dict(carat=carat, price=price, color=color))

fig, ax = plt.subplots()

colors = {'D':'red', 'E':'blue', 'F':'green', 'G':'black'}

ax.scatter(df['carat'], df['price'], c=df['color'].apply(lambda x: colors[x]))

plt.show()

df['color'].apply(lambda x: colors[x]) mappa efficacemente i colori dal "diamante" al "disegno".

(Perdonami per non aver inserito un'altra immagine di esempio, penso che 2 sia sufficiente: P)

Con seaborn

Puoi usare seabornche è un wrapper matplotlibche lo rende più carino per impostazione predefinita (piuttosto basato sull'opinione, lo so: P) ma aggiunge anche alcune funzioni di plottaggio.

Per questo potresti usare seaborn.lmplotcon fit_reg=False(che gli impedisce di fare automaticamente qualche regressione).

Il codice seguente utilizza un set di dati di esempio. Selezionando hue='color'si dice a seaborn di suddividere il dataframe in base ai colori e quindi di tracciare ciascuno di essi.

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd

carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color =['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G',]

df = pd.DataFrame(dict(carat=carat, price=price, color=color))

sns.lmplot('carat', 'price', data=df, hue='color', fit_reg=False)

plt.show()

inserisci qui la descrizione dell'immagine

Senza seabornusarepandas.groupby

Se non vuoi usare seaborn, puoi usare pandas.groupbyper ottenere i colori da soli e poi stamparli usando solo matplotlib, ma dovrai assegnare manualmente i colori mentre procedi, ho aggiunto un esempio di seguito:

fig, ax = plt.subplots()

colors = {'D':'red', 'E':'blue', 'F':'green', 'G':'black'}

grouped = df.groupby('color')
for key, group in grouped:
    group.plot(ax=ax, kind='scatter', x='carat', y='price', label=key, color=colors[key])

plt.show()

Questo codice presuppone lo stesso DataFrame di sopra e quindi lo raggruppa in base a color. Quindi itera su questi gruppi, tracciando per ciascuno di essi. Per selezionare un colore ho creato un colorsdizionario che può mappare il colore del diamante (ad esempio D) a un colore reale (ad esempio red).

inserisci qui la descrizione dell'immagine


Grazie, ma voglio solo scoprire come fare il lavoro con matplotlib da solo.
avocado

Sì, tramite groupbypotrei farlo, quindi c'è una tale funzionalità matplotlibche può disegnare automaticamente per diversi livelli di un categoriale usando colori diversi, giusto?
avocado

@loganecolss Ok vedo :) L'ho modificato di nuovo e ho aggiunto un esempio molto semplice che utilizza un dizionario per mappare i colori, in modo simile groupbyall'esempio.
Ffisegydd

1
@Ffisegydd Usando il primo metodo ax.scatter, come aggiungeresti le leggende? Sto cercando di utilizzare label=df['color']e quindi plt.legend()senza successo.
ahoosh

1
Sarebbe meglio passare ax.scatter(df['carat'], df['price'], c=df['color'].apply(lambda x: colors[x]))aax.scatter(df['carat'], df['price'], c=df['color'].map(colors)
Dawei

33

Ecco una soluzione succinta e generica per utilizzare una tavolozza di colori marini.

Per prima cosa trova una tavolozza di colori che ti piace e opzionalmente visualizzala:

sns.palplot(sns.color_palette("Set2", 8))

Quindi puoi usarlo matplotlibfacendo questo:

# Unique category labels: 'D', 'F', 'G', ...
color_labels = df['color'].unique()

# List of RGB triplets
rgb_values = sns.color_palette("Set2", 8)

# Map label to RGB
color_map = dict(zip(color_labels, rgb_values))

# Finally use the mapped values
plt.scatter(df['carat'], df['price'], c=df['color'].map(color_map))

2
Mi piace il tuo approccio. Dato l'esempio sopra, puoi ovviamente mappare i valori anche a nomi di colori semplici come questo: 1) definire i colori colors = {'D': 'red', 'E': 'blue', 'F': 'green ',' G ':' black '} 2) mappale come hai fatto tu: ax.scatter (df [' carat '], df [' price '], c = df [' color ']. Map (colors))
Stefan

1
Come aggiungeresti un'etichetta in base al colore in questo caso, però?
François Leblanc

2
Per aggiungere un po 'più di astrazione, puoi sostituire 8in sns.color_palette("Set2", 8)con len(color_labels).
Swier

È fantastico, ma dovrebbe essere fatto automaticamente da Seaborn. Dover usare una mappa per variabili categoriali ogni volta che vuoi tracciare qualcosa velocemente è incredibilmente difficile. Per non parlare dell'idea idiota di eliminare la possibilità di visualizzare le statistiche sulla trama. Seaborn, purtroppo, sta declinando come pacchetto a causa di questi motivi
insegui il

8

Avevo la stessa domanda e ho passato tutto il giorno a provare diversi pacchetti.

Inizialmente avevo usato matlibplot: e non ero soddisfatto di nessuna delle categorie di mappatura a colori predefiniti; o raggruppando / aggregando quindi iterando attraverso i gruppi (e dovendo ancora mappare i colori). Ho solo sentito che era una cattiva implementazione del pacchetto.

Seaborn non funzionerebbe sul mio caso e Altair funziona SOLO all'interno di un taccuino Jupyter.

La soluzione migliore per me era PlotNine, che "è un'implementazione di una grammatica grafica in Python e basata su ggplot2".

Di seguito è riportato il codice di plotnine per replicare il tuo esempio R in Python:

from plotnine import *
from plotnine.data import diamonds

g = ggplot(diamonds, aes(x='carat', y='price', color='color')) + geom_point(stat='summary')
print(g)

esempio di plottonine diamanti

Così pulito e semplice :)


Domanda posta per matplotlib
Chuck

6

Utilizzando Altair .

from altair import *
import pandas as pd

df = datasets.load_dataset('iris')
Chart(df).mark_point().encode(x='petalLength',y='sepalLength', color='species')

inserisci qui la descrizione dell'immagine


Domanda posta per matplotlib
Chuck

5

Ecco una combinazione di pennarelli e colori da una mappa colori qualitativa in matplotlib:

import itertools
import numpy as np
from matplotlib import markers
import matplotlib.pyplot as plt

m_styles = markers.MarkerStyle.markers
N = 60
colormap = plt.cm.Dark2.colors  # Qualitative colormap
for i, (marker, color) in zip(range(N), itertools.product(m_styles, colormap)):
    plt.scatter(*np.random.random(2), color=color, marker=marker, label=i)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=4);

inserisci qui la descrizione dell'immagine


In mpl.cm.Dark2.colors- mplnon sembra essere definito nel codice e Dark2non ha attributi colors.
Shovalt

@ Shovalt Grazie per la recensione. Avrei dovuto importare matplotlibcome mpl, ho corretto il mio codice usando pltche contiene anche cm. Almeno nella matplotlibversione che sto usando 2.0.0 Dark2ha un attributocolors
Pablo Reyes

1
In ritardo, ma se non hai l'attributo colors: iter (plt.cm.Dark2 (np.linspace (0,1, N)))
Geoff Lentsch

3

Con df.plot ()

Normalmente quando traccio rapidamente un DataFrame, uso pd.DataFrame.plot(). Questo prende l'indice come valore x, il valore come valore y e traccia ogni colonna separatamente con un colore diverso. Un DataFrame in questa forma può essere ottenuto utilizzando set_indexe unstack.

import matplotlib.pyplot as plt
import pandas as pd

carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color =['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G',]

df = pd.DataFrame(dict(carat=carat, price=price, color=color))

df.set_index(['color', 'carat']).unstack('color')['price'].plot(style='o')
plt.ylabel('price')

tracciare

Con questo metodo non è necessario specificare manualmente i colori.

Questa procedura può avere più senso per altre serie di dati. Nel mio caso ho i dati della serie temporale, quindi il MultiIndex è composto da datetime e categorie. È anche possibile utilizzare questo approccio per colorare più di una colonna, ma la legenda sta diventando un pasticcio.


0

Di solito lo faccio usando Seaborn che è costruito sopra matplotlib

import seaborn as sns
iris = sns.load_dataset('iris')
sns.scatterplot(x='sepal_length', y='sepal_width',
              hue='species', data=iris); 
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.