Come funziona python numpy.where ()?


94

Sto giocando numpye scavando nella documentazione e mi sono imbattuto in un po 'di magia. Vale a dire sto parlando di numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

Come fanno a ottenere internamente che tu sia in grado di passare qualcosa di simile x > 5a un metodo? Immagino abbia qualcosa a che fare con, __gt__ma sto cercando una spiegazione dettagliata.

Risposte:


75

Come fanno a ottenere internamente che tu sia in grado di passare qualcosa come x> 5 in un metodo?

La risposta breve è che non lo fanno.

Qualsiasi tipo di operazione logica su un array numpy restituisce un array booleano. (cioè __gt__, __lt__ecc. tutti restituiscono array booleani in cui la condizione data è vera).

Per esempio

x = np.arange(9).reshape(3,3)
print x > 5

rende:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

Questa è la stessa ragione per cui qualcosa come if x > 5:solleva xun'eccezione ValueError se è un array numpy. È un array di valori True / False, non un singolo valore.

Inoltre, gli array numpy possono essere indicizzati da array booleani. Ad esempio x[x>5]rendimenti [6 7 8], in questo caso.

Onestamente, è abbastanza raro che tu abbia effettivamente bisogno, numpy.wherema restituisce solo le indicazioni su dove si trova un array booleano True. Di solito puoi fare quello che ti serve con una semplice indicizzazione booleana.


10
Giusto per sottolineare che numpy.wherehanno 2 'modalità operative', il primo restituisce indices, dove condition is Truee se sono presenti parametri opzionali xe y(stessa forma di condition, o trasmettibile a tale forma!), Restituirà valori da xquando condition is Truealtrimenti da y. Quindi questo rende wherepiù versatile e consente di essere utilizzato più spesso. Grazie
mangia il

1
In alcuni casi, l'utilizzo di __getitem__ sintassi di []over numpy.whereo numpy.take. Poiché __getitem__deve supportare anche lo slicing, c'è un po 'di overhead. Ho notato notevoli differenze di velocità quando si lavora con le strutture dati di Python Pandas e si indicizzano logicamente colonne molto grandi. In quei casi, se non hai bisogno di affettare, allora takee wheresono effettivamente migliori.
ely

24

Vecchia risposta è un po 'confusa. Ti dà le POSIZIONI (tutte) in cui la tua dichiarazione è vera.

così:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

Lo uso come alternativa a list.index (), ma ha anche molti altri usi. Non l'ho mai usato con array 2D.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

Nuova risposta Sembra che la persona stesse chiedendo qualcosa di più fondamentale.

La domanda era come potresti implementare qualcosa che consente a una funzione (come dove) di sapere cosa è stato richiesto.

Prima nota che chiamare uno qualsiasi degli operatori di confronto fa una cosa interessante.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

Questo viene fatto sovraccaricando il metodo "__gt__". Per esempio:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

Come puoi vedere, "a> 4" era un codice valido.

Puoi ottenere un elenco completo e la documentazione di tutte le funzioni sovraccariche qui: http://docs.python.org/reference/datamodel.html

Qualcosa di incredibile è quanto sia semplice farlo. TUTTE le operazioni in Python vengono eseguite in questo modo. Dire a> b è equivalente ad a.gt (b)!


3
Questo sovraccarico di operatori di confronto non sembra funzionare bene con espressioni logiche più complesse, ad esempio non posso farlo np.where(a > 30 and a < 50)o np.where(30 < a < 50)perché finisce per provare a valutare l'AND logico di due array di booleani, il che è piuttosto privo di significato. C'è un modo per scrivere una tale condizione np.where?
davidA

@meowsqueaknp.where((a > 30) & (a < 50))
tibalt

Perché np.where () restituisce un elenco nel tuo esempio?
Andreas Yankopolus

0

np.where restituisce una tupla di lunghezza uguale alla dimensione del numpy ndarray su cui è chiamato (in altre parole ndim ) e ogni elemento della tupla è un numpy ndarray di indici di tutti quei valori nel ndarray iniziale per cui la condizione è True. (Si prega di non confondere la dimensione con la forma)

Per esempio:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y è una tupla di lunghezza 2 perché x.ndim è 2. Il primo elemento nella tupla contiene i numeri di riga di tutti gli elementi maggiori di 4 e il secondo elemento contiene i numeri di colonna di tutti gli elementi maggiori di 4. Come puoi vedere, [1,2,2 , 2] corrisponde ai numeri di riga di 5,6,7,8 e [2,0,1,2] corrisponde ai numeri di colonna di 5,6,7,8 Si noti che il ndarray è attraversato lungo la prima dimensione (riga per riga ).

Allo stesso modo,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


restituirà una tupla di lunghezza 3 perché x ha 3 dimensioni.

Ma aspetta, c'è di più da np. Dove!

quando vengono aggiunti due argomenti aggiuntivi a np.where; eseguirà un'operazione di sostituzione per tutte quelle combinazioni riga-colonna a coppie ottenute dalla tupla sopra.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.