Verifica se l'array numpy contiene solo zeri


92

Inizializziamo un array numpy con zeri come sotto:

np.zeros((N,N+1))

Ma come controlliamo se tutti gli elementi in una data matrice di array n * n numpy sono zero.
Il metodo deve solo restituire un True se tutti i valori sono effettivamente zero.

Risposte:



161

Le altre risposte pubblicate qui funzioneranno, ma la funzione più chiara ed efficiente da utilizzare è numpy.any():

>>> all_zeros = not np.any(a)

o

>>> all_zeros = not a.any()
  • Questo è preferibile numpy.all(a==0)perché utilizza meno RAM. (Non richiede l'array temporaneo creato dal a==0termine.)
  • Inoltre, è più veloce di numpy.count_nonzero(a)perché può tornare immediatamente quando è stato trovato il primo elemento diverso da zero.
    • Modifica: come ha sottolineato @Rachel nei commenti, np.any()non usa più la logica "cortocircuito", quindi non vedrai un vantaggio in termini di velocità per i piccoli array.

2
Fino a un minuto fa, numpy anye nonall vanno in cortocircuito. Credo che siano zucchero per e . Confronta tra loro e il mio cortocircuito : logical_or.reducelogical_and.reduceis_inall_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel

2
Questo è un ottimo punto, grazie. Sembra corto circuito utilizzato per essere il comportamento, ma che è stato perso a un certo punto. C'è qualche discussione interessante nelle risposte a questa domanda .
Stuart Berg

50

Userei np.all qui, se hai un array a:

>>> np.all(a==0)

3
Mi piace che questa risposta controlli anche valori diversi da zero. Ad esempio, si può verificare se tutti gli elementi in un array sono gli stessi facendo np.all(a==a[0]). Molte grazie!
aignas

9

Come dice un'altra risposta, puoi trarre vantaggio da valutazioni veritiere / false se sai che 0è l'unico elemento falso possibilmente nel tuo array. Tutti gli elementi in un array sono falsi se e solo se non ci sono elementi veritieri in esso. *

>>> a = np.zeros(10)
>>> not np.any(a)
True

Tuttavia, la risposta affermava che anyera più veloce di altre opzioni a causa in parte del cortocircuito. A partire dal 2018, Numpy alle any non cortocircuitano .

Se fai spesso questo genere di cose, è molto facile creare le tue versioni in cortocircuito usando numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Questi tendono ad essere più veloci delle versioni di Numpy anche quando non sono in cortocircuito. count_nonzeroè il più lento.

Alcuni input per controllare le prestazioni:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Dai un'occhiata:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Utile alled anyequivalenze:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

-9

Se stai testando tutti gli zeri per evitare un avviso su un'altra funzione numpy, avvolgere la linea in una prova, tranne il blocco salverà dover fare il test per gli zeri prima dell'operazione a cui sei interessato, ad es.

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.