Ho messo insieme alcuni esempi di codice tensorflow per aiutare a spiegare (il codice completo e funzionante è in questa sintesi ). Questo codice implementa la rete di capsule dalla prima parte della sezione 2 nel documento che hai collegato:
N_REC_UNITS = 10
N_GEN_UNITS = 20
N_CAPSULES = 30
# input placeholders
img_input_flat = tf.placeholder(tf.float32, shape=(None, 784))
d_xy = tf.placeholder(tf.float32, shape=(None, 2))
# translate the image according to d_xy
img_input = tf.reshape(img_input_flat, (-1, 28, 28, 1))
trans_img = image.translate(img_input, d_xy)
flat_img = tf.layers.flatten(trans_img)
capsule_img_list = []
# build several capsules and store the generated output in a list
for i in range(N_CAPSULES):
# hidden recognition layer
h_rec = tf.layers.dense(flat_img, N_REC_UNITS, activation=tf.nn.relu)
# inferred xy values
xy = tf.layers.dense(h_rec, 2) + d_xy
# inferred probability of feature
p = tf.layers.dense(h_rec, 1, activation=tf.nn.sigmoid)
# hidden generative layer
h_gen = tf.layers.dense(xy, N_GEN_UNITS, activation=tf.nn.relu)
# the flattened generated image
cap_img = p*tf.layers.dense(h_gen, 784, activation=tf.nn.relu)
capsule_img_list.append(cap_img)
# combine the generated images
gen_img_stack = tf.stack(capsule_img_list, axis=1)
gen_img = tf.reduce_sum(gen_img_stack, axis=1)
Qualcuno sa come dovrebbe funzionare la mappatura tra i pixel di input e le capsule?
Questo dipende dalla struttura della rete. Per il primo esperimento in quel documento (e il codice sopra), ogni capsula ha un campo ricettivo che include l'intera immagine di input. Questa è la disposizione più semplice. In tal caso, è uno strato completamente collegato tra l'immagine in ingresso e il primo livello nascosto in ogni capsula.
In alternativa, i campi ricettivi della capsula possono essere disposti più come i gherigli della CNN con i passi, come negli esperimenti successivi in quel documento.
Cosa dovrebbe accadere esattamente nelle unità di riconoscimento?
Le unità di riconoscimento sono una rappresentazione interna che ogni capsula ha. Ogni capsula utilizza questa rappresentazione interna per calcolare p
, la probabilità che sia presente la funzione della capsula e xy
i valori di traduzione dedotti. La figura 2 in quel documento è un controllo per assicurarsi che la rete stia imparando a usare xy
correttamente (lo è).
Come dovrebbe essere addestrato? È solo un puntello posteriore standard tra ogni connessione?
In particolare, è necessario addestrarlo come un codificatore automatico, utilizzando una perdita che applica la somiglianza tra l'output generato e l'originale. Errore quadratico medio funziona bene qui. A parte questo, sì, dovrai propagare la discesa del gradiente con backprop.
loss = tf.losses.mean_squared_error(img_input_flat, gen_img)
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)