Voglio spiegare la semplice risposta con varie note di performance. np.linalg.norm farà forse più del necessario:
dist = numpy.linalg.norm(a-b)
Innanzitutto, questa funzione è progettata per lavorare su un elenco e restituire tutti i valori, ad esempio per confrontare la distanza pAdall'insieme di punti sP:
sP = set(points)
pA = point
distances = np.linalg.norm(sP - pA, ord=2, axis=1.) # 'distances' is a list
Ricorda diverse cose:
- Le chiamate alle funzioni Python sono costose.
- [Regular] Python non memorizza nella cache le ricerche dei nomi.
Così
def distance(pointA, pointB):
dist = np.linalg.norm(pointA - pointB)
return dist
non è così innocente come sembra.
>>> dis.dis(distance)
2 0 LOAD_GLOBAL 0 (np)
2 LOAD_ATTR 1 (linalg)
4 LOAD_ATTR 2 (norm)
6 LOAD_FAST 0 (pointA)
8 LOAD_FAST 1 (pointB)
10 BINARY_SUBTRACT
12 CALL_FUNCTION 1
14 STORE_FAST 2 (dist)
3 16 LOAD_FAST 2 (dist)
18 RETURN_VALUE
In primo luogo - ogni volta che lo chiamiamo, dobbiamo fare una ricerca globale per "np", una ricerca con ambito per "linalg" e una ricerca con ambito per "norma", e il sovraccarico di semplicemente chiamare la funzione può equivalere a dozzine di pitone Istruzioni.
Infine, abbiamo sprecato due operazioni per memorizzare il risultato e ricaricarlo per il ritorno ...
Primo passaggio al miglioramento: velocizza la ricerca, salta il negozio
def distance(pointA, pointB, _norm=np.linalg.norm):
return _norm(pointA - pointB)
Otteniamo il più snello:
>>> dis.dis(distance)
2 0 LOAD_FAST 2 (_norm)
2 LOAD_FAST 0 (pointA)
4 LOAD_FAST 1 (pointB)
6 BINARY_SUBTRACT
8 CALL_FUNCTION 1
10 RETURN_VALUE
Tuttavia, l'overhead della chiamata di funzione equivale ancora a qualche lavoro. E vorrai fare dei benchmark per determinare se potresti fare meglio la matematica da solo:
def distance(pointA, pointB):
return (
((pointA.x - pointB.x) ** 2) +
((pointA.y - pointB.y) ** 2) +
((pointA.z - pointB.z) ** 2)
) ** 0.5 # fast sqrt
Su alcune piattaforme, **0.5è più veloce di math.sqrt. Il tuo chilometraggio può variare.
**** Note sulle prestazioni avanzate.
Perché stai calcolando la distanza? Se l'unico scopo è visualizzarlo,
print("The target is %.2fm away" % (distance(a, b)))
andare avanti. Ma se si confrontano le distanze, si eseguono controlli di portata, ecc., Vorrei aggiungere alcune utili osservazioni sulle prestazioni.
Prendiamo due casi: ordinamento per distanza o abbattimento di un elenco per gli elementi che soddisfano un vincolo di intervallo.
# Ultra naive implementations. Hold onto your hat.
def sort_things_by_distance(origin, things):
return things.sort(key=lambda thing: distance(origin, thing))
def in_range(origin, range, things):
things_in_range = []
for thing in things:
if distance(origin, thing) <= range:
things_in_range.append(thing)
La prima cosa che dobbiamo ricordare è che stiamo usando Pitagora per calcolare la distanza ( dist = sqrt(x^2 + y^2 + z^2)), quindi stiamo facendo molte sqrtchiamate. Matematica 101:
dist = root ( x^2 + y^2 + z^2 )
:.
dist^2 = x^2 + y^2 + z^2
and
sq(N) < sq(M) iff M > N
and
sq(N) > sq(M) iff N > M
and
sq(N) = sq(M) iff N == M
In breve: fino a quando non richiediamo effettivamente la distanza in un'unità di X anziché X ^ 2, possiamo eliminare la parte più difficile dei calcoli.
# Still naive, but much faster.
def distance_sq(left, right):
""" Returns the square of the distance between left and right. """
return (
((left.x - right.x) ** 2) +
((left.y - right.y) ** 2) +
((left.z - right.z) ** 2)
)
def sort_things_by_distance(origin, things):
return things.sort(key=lambda thing: distance_sq(origin, thing))
def in_range(origin, range, things):
things_in_range = []
# Remember that sqrt(N)**2 == N, so if we square
# range, we don't need to root the distances.
range_sq = range**2
for thing in things:
if distance_sq(origin, thing) <= range_sq:
things_in_range.append(thing)
Fantastico, entrambe le funzioni non fanno più costose radici quadrate. Sarà molto più veloce. Possiamo anche migliorare in_range convertendolo in un generatore:
def in_range(origin, range, things):
range_sq = range**2
yield from (thing for thing in things
if distance_sq(origin, thing) <= range_sq)
Ciò ha in particolare dei vantaggi se stai facendo qualcosa del tipo:
if any(in_range(origin, max_dist, things)):
...
Ma se la prossima cosa che farai richiede una distanza,
for nearby in in_range(origin, walking_distance, hotdog_stands):
print("%s %.2fm" % (nearby.name, distance(origin, nearby)))
considera di dare tuple:
def in_range_with_dist_sq(origin, range, things):
range_sq = range**2
for thing in things:
dist_sq = distance_sq(origin, thing)
if dist_sq <= range_sq: yield (thing, dist_sq)
Ciò può essere particolarmente utile se si potrebbero verificare i controlli della catena ("trovare oggetti vicini a X e entro Nm di Y", poiché non è necessario calcolare nuovamente la distanza).
Ma cosa succede se stiamo cercando un elenco davvero ampio di thingse prevediamo che molti di loro non vadano presi in considerazione?
Esiste in realtà un'ottimizzazione molto semplice:
def in_range_all_the_things(origin, range, things):
range_sq = range**2
for thing in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
Se questo sia utile dipenderà dalla dimensione delle "cose".
def in_range_all_the_things(origin, range, things):
range_sq = range**2
if len(things) >= 4096:
for thing in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
elif len(things) > 32:
for things in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2 + (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
else:
... just calculate distance and range-check it ...
E ancora, considera di cedere dist_sq. Il nostro esempio di hot dog diventa quindi:
# Chaining generators
info = in_range_with_dist_sq(origin, walking_distance, hotdog_stands)
info = (stand, dist_sq**0.5 for stand, dist_sq in info)
for stand, dist in info:
print("%s %.2fm" % (stand, dist))