Etichette in linea in Matplotlib


100

In Matplotlib, non è troppo difficile creare una legenda ( example_legend(), sotto), ma penso che sia uno stile migliore per mettere etichette proprio sulle curve che vengono tracciate (come in example_inline(), sotto). Questo può essere molto complicato, perché devo specificare le coordinate a mano e, se riformatto la trama, probabilmente devo riposizionare le etichette. Esiste un modo per generare automaticamente etichette sulle curve in Matplotlib? Punti bonus per essere in grado di orientare il testo con un angolo corrispondente all'angolo della curva.

import numpy as np
import matplotlib.pyplot as plt

def example_legend():
    plt.clf()
    x = np.linspace(0, 1, 101)
    y1 = np.sin(x * np.pi / 2)
    y2 = np.cos(x * np.pi / 2)
    plt.plot(x, y1, label='sin')
    plt.plot(x, y2, label='cos')
    plt.legend()

Figura con legenda

def example_inline():
    plt.clf()
    x = np.linspace(0, 1, 101)
    y1 = np.sin(x * np.pi / 2)
    y2 = np.cos(x * np.pi / 2)
    plt.plot(x, y1, label='sin')
    plt.plot(x, y2, label='cos')
    plt.text(0.08, 0.2, 'sin')
    plt.text(0.9, 0.2, 'cos')

Figura con etichette in linea

Risposte:


28

Bella domanda, tempo fa ho sperimentato un po 'con questo, ma non l'ho usato molto perché non è ancora a prova di proiettile. Ho diviso l'area del tracciato in una griglia 32x32 e ho calcolato un "campo potenziale" per la migliore posizione di un'etichetta per ogni riga secondo le seguenti regole:

  • lo spazio bianco è un buon posto per un'etichetta
  • L'etichetta dovrebbe essere vicino alla riga corrispondente
  • L'etichetta dovrebbe essere lontana dalle altre righe

Il codice era qualcosa del genere:

import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage


def my_legend(axis = None):

    if axis == None:
        axis = plt.gca()

    N = 32
    Nlines = len(axis.lines)
    print Nlines

    xmin, xmax = axis.get_xlim()
    ymin, ymax = axis.get_ylim()

    # the 'point of presence' matrix
    pop = np.zeros((Nlines, N, N), dtype=np.float)    

    for l in range(Nlines):
        # get xy data and scale it to the NxN squares
        xy = axis.lines[l].get_xydata()
        xy = (xy - [xmin,ymin]) / ([xmax-xmin, ymax-ymin]) * N
        xy = xy.astype(np.int32)
        # mask stuff outside plot        
        mask = (xy[:,0] >= 0) & (xy[:,0] < N) & (xy[:,1] >= 0) & (xy[:,1] < N)
        xy = xy[mask]
        # add to pop
        for p in xy:
            pop[l][tuple(p)] = 1.0

    # find whitespace, nice place for labels
    ws = 1.0 - (np.sum(pop, axis=0) > 0) * 1.0 
    # don't use the borders
    ws[:,0]   = 0
    ws[:,N-1] = 0
    ws[0,:]   = 0  
    ws[N-1,:] = 0  

    # blur the pop's
    for l in range(Nlines):
        pop[l] = ndimage.gaussian_filter(pop[l], sigma=N/5)

    for l in range(Nlines):
        # positive weights for current line, negative weight for others....
        w = -0.3 * np.ones(Nlines, dtype=np.float)
        w[l] = 0.5

        # calculate a field         
        p = ws + np.sum(w[:, np.newaxis, np.newaxis] * pop, axis=0)
        plt.figure()
        plt.imshow(p, interpolation='nearest')
        plt.title(axis.lines[l].get_label())

        pos = np.argmax(p)  # note, argmax flattens the array first 
        best_x, best_y =  (pos / N, pos % N) 
        x = xmin + (xmax-xmin) * best_x / N       
        y = ymin + (ymax-ymin) * best_y / N       


        axis.text(x, y, axis.lines[l].get_label(), 
                  horizontalalignment='center',
                  verticalalignment='center')


plt.close('all')

x = np.linspace(0, 1, 101)
y1 = np.sin(x * np.pi / 2)
y2 = np.cos(x * np.pi / 2)
y3 = x * x
plt.plot(x, y1, 'b', label='blue')
plt.plot(x, y2, 'r', label='red')
plt.plot(x, y3, 'g', label='green')
my_legend()
plt.show()

E la trama risultante: inserisci qui la descrizione dell'immagine


Molto bella. Tuttavia, ho un esempio che non funziona completamente: plt.plot(x2, 3*x2**2, label="3x*x"); plt.plot(x2, 2*x2**2, label="2x*x"); plt.plot(x2, 0.5*x2**2, label="0.5x*x"); plt.plot(x2, -1*x2**2, label="-x*x"); plt.plot(x2, -2.5*x2**2, label="-2.5*x*x"); my_legend();questo mette una delle etichette nell'angolo in alto a sinistra. Qualche idea su come risolvere questo problema? Sembra che il problema potrebbe essere che le linee sono troppo vicine tra loro.
egpbos

Scusa, dimenticavo x2 = np.linspace(0,0.5,100).
egpbos

C'è un modo per usarlo senza scipy? Sul mio sistema attuale è una seccatura da installare.
AnnanFay

Questo non funziona per me con Python 3.6.4, Matplotlib 2.1.2 e Scipy 1.0.0. Dopo aver aggiornato il printcomando, viene eseguito e crea 4 grafici, 3 dei quali sembrano essere pixelati senza senso (probabilmente qualcosa a che fare con il 32x32) e il quarto con etichette in posti dispari.
Y Davis,

84

Aggiornamento: l' utente cphyc ha gentilmente creato un repository Github per il codice in questa risposta (vedi qui ) e ha raggruppato il codice in un pacchetto che può essere installato utilizzando pip install matplotlib-label-lines.


Bella immagine:

etichettatura semiautomatica della trama

In matplotlibè abbastanza facile etichettare i grafici di contorno (automaticamente o posizionando manualmente le etichette con i clic del mouse). Non sembra (ancora) esserci alcuna capacità equivalente per etichettare le serie di dati in questo modo! Potrebbe esserci qualche ragione semantica per non includere questa caratteristica che mi manca.

Indipendentemente da ciò, ho scritto il seguente modulo che accetta qualsiasi autorizzazione per l'etichettatura semiautomatica del diagramma. Richiede solo numpye un paio di funzioni dalla mathlibreria standard .

Descrizione

Il comportamento predefinito della labelLinesfunzione è di distanziare le etichette in modo uniforme lungo l' xasse (posizionando automaticamente il valore corretto yovviamente). Se vuoi puoi semplicemente passare un array delle coordinate x di ciascuna delle etichette. Puoi anche modificare la posizione di un'etichetta (come mostrato nella trama in basso a destra) e distanziare il resto in modo uniforme, se lo desideri.

Inoltre, la label_linesfunzione non tiene conto delle righe a cui non è stata assegnata un'etichetta nel plotcomando (o più precisamente se l'etichetta contiene '_line').

Gli argomenti della parola chiave passati labelLineso labelLinevengono passati alla textchiamata di funzione (alcuni argomenti della parola chiave vengono impostati se il codice chiamante sceglie di non specificare).

Problemi

  • I rettangoli di delimitazione delle annotazioni a volte interferiscono in modo indesiderato con altre curve. Come mostrato dalle annotazioni 1e 10nel grafico in alto a sinistra. Non sono nemmeno sicuro che questo possa essere evitato.
  • Sarebbe bello yinvece specificare una posizione a volte.
  • È ancora un processo iterativo per ottenere le annotazioni nella posizione corretta
  • Funziona solo quando i xvalori -axis sono floats

Trabocchetti

  • Per impostazione predefinita, la labelLinesfunzione presuppone che tutte le serie di dati coprano l'intervallo specificato dai limiti dell'asse. Dai un'occhiata alla curva blu nella trama in alto a sinistra della bella immagine. Se fossero disponibili solo dati per l' xintervallo 0.5, 1non potremmo posizionare un'etichetta nella posizione desiderata (che è leggermente inferiore a 0.2). Vedi questa domanda per un esempio particolarmente sgradevole. Al momento, il codice non identifica in modo intelligente questo scenario e riorganizza le etichette, tuttavia esiste una soluzione alternativa ragionevole. La funzione labelLines accetta l' xvalsargomento; un elenco di xvalori specificati dall'utente invece della distribuzione lineare predefinita sulla larghezza. Quindi l'utente può decidere qualex-valori da utilizzare per il posizionamento dell'etichetta di ciascuna serie di dati.

Inoltre, credo che questa sia la prima risposta per completare l' obiettivo bonus di allineare le etichette con la curva su cui si trovano. :)

label_lines.py:

from math import atan2,degrees
import numpy as np

#Label line with line2D label data
def labelLine(line,x,label=None,align=True,**kwargs):

    ax = line.axes
    xdata = line.get_xdata()
    ydata = line.get_ydata()

    if (x < xdata[0]) or (x > xdata[-1]):
        print('x label location is outside data range!')
        return

    #Find corresponding y co-ordinate and angle of the line
    ip = 1
    for i in range(len(xdata)):
        if x < xdata[i]:
            ip = i
            break

    y = ydata[ip-1] + (ydata[ip]-ydata[ip-1])*(x-xdata[ip-1])/(xdata[ip]-xdata[ip-1])

    if not label:
        label = line.get_label()

    if align:
        #Compute the slope
        dx = xdata[ip] - xdata[ip-1]
        dy = ydata[ip] - ydata[ip-1]
        ang = degrees(atan2(dy,dx))

        #Transform to screen co-ordinates
        pt = np.array([x,y]).reshape((1,2))
        trans_angle = ax.transData.transform_angles(np.array((ang,)),pt)[0]

    else:
        trans_angle = 0

    #Set a bunch of keyword arguments
    if 'color' not in kwargs:
        kwargs['color'] = line.get_color()

    if ('horizontalalignment' not in kwargs) and ('ha' not in kwargs):
        kwargs['ha'] = 'center'

    if ('verticalalignment' not in kwargs) and ('va' not in kwargs):
        kwargs['va'] = 'center'

    if 'backgroundcolor' not in kwargs:
        kwargs['backgroundcolor'] = ax.get_facecolor()

    if 'clip_on' not in kwargs:
        kwargs['clip_on'] = True

    if 'zorder' not in kwargs:
        kwargs['zorder'] = 2.5

    ax.text(x,y,label,rotation=trans_angle,**kwargs)

def labelLines(lines,align=True,xvals=None,**kwargs):

    ax = lines[0].axes
    labLines = []
    labels = []

    #Take only the lines which have labels other than the default ones
    for line in lines:
        label = line.get_label()
        if "_line" not in label:
            labLines.append(line)
            labels.append(label)

    if xvals is None:
        xmin,xmax = ax.get_xlim()
        xvals = np.linspace(xmin,xmax,len(labLines)+2)[1:-1]

    for line,x,label in zip(labLines,xvals,labels):
        labelLine(line,x,label,align,**kwargs)

Prova il codice per generare la bella immagine sopra:

from matplotlib import pyplot as plt
from scipy.stats import loglaplace,chi2

from labellines import *

X = np.linspace(0,1,500)
A = [1,2,5,10,20]
funcs = [np.arctan,np.sin,loglaplace(4).pdf,chi2(5).pdf]

plt.subplot(221)
for a in A:
    plt.plot(X,np.arctan(a*X),label=str(a))

labelLines(plt.gca().get_lines(),zorder=2.5)

plt.subplot(222)
for a in A:
    plt.plot(X,np.sin(a*X),label=str(a))

labelLines(plt.gca().get_lines(),align=False,fontsize=14)

plt.subplot(223)
for a in A:
    plt.plot(X,loglaplace(4).pdf(a*X),label=str(a))

xvals = [0.8,0.55,0.22,0.104,0.045]
labelLines(plt.gca().get_lines(),align=False,xvals=xvals,color='k')

plt.subplot(224)
for a in A:
    plt.plot(X,chi2(5).pdf(a*X),label=str(a))

lines = plt.gca().get_lines()
l1=lines[-1]
labelLine(l1,0.6,label=r'$Re=${}'.format(l1.get_label()),ha='left',va='bottom',align = False)
labelLines(lines[:-1],align=False)

plt.show()

1
@blujay sono contento che tu sia riuscito ad adattarlo alle tue esigenze. Aggiungerò questo vincolo come problema.
NauticalMile

1
@Liza Leggi il mio Gotcha che ho appena aggiunto per spiegare perché sta accadendo. Nel tuo caso (presumo sia come quello in questa domanda ) a meno che tu non voglia creare manualmente un elenco di xvals, potresti voler modificare labelLinesun po 'il codice: cambia il codice nell'ambito if xvals is None:per creare un elenco basato su altri criteri. Potresti iniziare conxvals = [(np.min(l.get_xdata())+np.max(l.get_xdata()))/2 for l in lines]
NauticalMile

1
@Liza Il tuo grafico però mi intriga. Il problema è che i tuoi dati non sono distribuiti uniformemente sul grafico e hai molte curve che sono quasi una sopra l'altra. Con la mia soluzione potrebbe essere molto difficile distinguere le etichette in molti casi. Penso che la soluzione migliore sia avere blocchi di etichette impilate in diverse parti vuote della trama. Vedi questo grafico per un esempio con due blocchi di etichette impilate (un blocco con 1 etichetta e un altro blocco con 4). L'implementazione di questo sarebbe un bel po 'di lavoro di gamba, potrei farlo ad un certo punto in futuro.
NauticalMile

1
Nota: da Matplotlib 2.0 .get_axes()e .get_axis_bgcolor()sono stati deprecati. Si prega di sostituire con .axese .get_facecolor(), risp.
Jiāgěng

1
Un'altra cosa fantastica labellinesè che le proprietà sono correlate plt.texto si ax.textapplicano ad esso. Significa che è possibile impostare fontsizee bboxparametri nella labelLines()funzione.
tionichm

52

La risposta di @Jan Kuiken è certamente ben ponderata e approfondita, ma ci sono alcuni avvertimenti:

  • non funziona in tutti i casi
  • richiede una discreta quantità di codice extra
  • può variare notevolmente da un appezzamento all'altro

Un approccio molto più semplice consiste nell'annotare l'ultimo punto di ogni trama. Il punto può anche essere cerchiato, per enfasi. Questo può essere ottenuto con una riga aggiuntiva:

from matplotlib import pyplot as plt

for i, (x, y) in enumerate(samples):
    plt.plot(x, y)
    plt.text(x[-1], y[-1], 'sample {i}'.format(i=i))

Una variante potrebbe essere quella di utilizzare ax.annotate.


1
+1! Sembra una soluzione carina e semplice. Scusa per la pigrizia, ma come sarebbe questo? Il testo sarebbe all'interno del grafico o in cima all'asse y di destra?
rocarvaj

1
@rocarvaj Dipende da altre impostazioni. È possibile che le etichette sporgano fuori dal riquadro del grafico. Due modi per evitare questo comportamento sono: 1) utilizzare un indice diverso da -1, 2) impostare i limiti degli assi appropriati per lasciare spazio per le etichette.
Ioannis Filippidis

1
Diventa anche un pasticcio, se le trame si concentrano su un valore y - gli endpoint diventano troppo vicini perché il testo abbia un bell'aspetto
LazyCat

@LazyCat: è vero. Per risolvere questo problema, è possibile rendere trascinabili le annotazioni. Un po 'un dolore immagino, ma farebbe il trucco.
PlacidLush

1

Un approccio più semplice come quello di Ioannis Filippidis:

import matplotlib.pyplot as plt
import numpy as np

# evenly sampled time at 200ms intervals
tMin=-1 ;tMax=10
t = np.arange(tMin, tMax, 0.1)

# red dashes, blue points default
plt.plot(t, 22*t, 'r--', t, t**2, 'b')

factor=3/4 ;offset=20  # text position in view  
textPosition=[(tMax+tMin)*factor,22*(tMax+tMin)*factor]
plt.text(textPosition[0],textPosition[1]+offset,'22  t',color='red',fontsize=20)
textPosition=[(tMax+tMin)*factor,((tMax+tMin)*factor)**2+20]
plt.text(textPosition[0],textPosition[1]+offset, 't^2', bbox=dict(facecolor='blue', alpha=0.5),fontsize=20)
plt.show()

codice python 3 su sageCell

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.