Python >> Tutoriel Python >  >> Python Tag >> Keras

Couche conditionnelle du modèle CNN dans Keras

Le problème des conditions dans les réseaux de neurones

Le problème avec un commutateur ou des conditions (comme si-alors-sinon) dans le cadre d'un réseau de neurones est que les conditions ne sont pas différenciables partout. Par conséquent, les méthodes de différenciation automatique ne fonctionneraient pas directement et résoudre ce problème est extrêmement complexe. Cochez ceci pour plus de détails.

Un raccourci est que vous pouvez finir par former 3 modèles distincts indépendamment, puis utiliser pendant l'inférence un flux de contrôle de conditions pour en déduire.

#Training - 
model1 = model.fit(all images, P(cat/dog))
model2 = model.fit(all images, P(cat))
model3 = model.fit(all images, P(dog))
final prediction = argmax(model2, model3)

#Inference - 
if model1.predict == Cat: 
    model2.predict
else:
    model3.predict

Mais je ne pense pas que tu cherches ça. Je pense que vous cherchez à inclure des conditions dans le cadre du graphe de calcul lui-même.

Malheureusement, il n'y a aucun moyen direct pour vous de construire une condition si-alors dans le cadre d'un graphique de calcul selon ma connaissance. Le keras.switch que vous voyez vous permet de travailler avec des sorties de tenseur mais pas avec des couches d'un graphique pendant la formation. C'est pourquoi vous le verrez être utilisé dans le cadre des fonctions de perte et non dans les graphiques de calcul (lance des erreurs d'entrée).

Une solution possible – Ignorer les connexions et la commutation logicielle

Vous pouvez cependant essayer de construire quelque chose de similaire avec skip connections et soft switching .

Une connexion de saut est une connexion d'une couche précédente à une autre couche qui vous permet de transmettre des informations aux couches suivantes. Ceci est assez courant dans les réseaux très profonds où les informations des données d'origine sont ensuite perdues. Vérifiez U-net ou Resnet par exemple, qui utilise des connexions de saut entre les couches pour transmettre des informations aux couches futures.

Le problème suivant est la question de la commutation. Vous souhaitez basculer entre 2 chemins possibles dans le graphe. Ce que vous pouvez faire est une méthode de commutation douce que je me suis inspirée de cet article. Notez que pour switch entre 2 distributions de mots (une du décodeur et une autre de l'entrée), les auteurs les multiplient par p et (1-p) pour obtenir une distribution cumulée. Il s'agit d'un commutateur logiciel qui permet au modèle de sélectionner le prochain mot prédit à partir du décodeur ou de l'entrée elle-même. (aide lorsque vous voulez que votre chatbot prononce les mots saisis par l'utilisateur dans le cadre de sa réponse !)

Avec une compréhension de ces 2 concepts, essayons de construire intuitivement notre architecture.

  1. Nous avons d'abord besoin d'un graphique à entrée unique et sorties multiples puisque nous entraînons 2 modèles

  2. Notre premier modèle est une classification multi-classes qui prédit les probabilités individuelles pour Cat et Dog séparément. Cela sera formé avec l'activation de softmax et un categorical_crossentropy perte.

  3. Ensuite, prenons le logit qui prédit la probabilité de Cat et multiplions la couche de convolution 3 avec celui-ci. Cela peut être fait avec un Lambda calque.

  4. Et de même, prenons la probabilité de Dog et multiplions-la par la couche de convolution 2. Cela peut être vu comme suit -

    • Si mon premier modèle prédit parfaitement un chat et non un chien, alors le calcul sera 1*(Conv3) et 0*(Conv2) .
    • Si le premier modèle prédit parfaitement un chien et non un chat, alors le calcul sera 0*(Conv3) et 1*(Conv2)
    • Vous pouvez considérer cela comme un soft-switch OU un forget gate de LSTM. Le forget gate est une sortie sigmoïde (0 à 1) qui multiplie l'état de la cellule pour la déclencher et permettre au LSTM d'oublier ou de se souvenir des pas de temps précédents. Concept similaire ici !
  5. Ces Conv3 et Conv2 peuvent maintenant être traitées, aplaties, concaténées et transmises à une autre couche Dense pour la prédiction finale.

De cette façon, si le modèle n'est pas sûr d'un chien ou d'un chat, les caractéristiques conv2 et conv3 participent aux prédictions du second modèle. Voici comment vous pouvez utiliser skip connections et soft switch mécanisme inspiré pour ajouter une certaine quantité de flux de contrôle conditionnel à votre réseau.

Vérifiez mon implémentation du graphe de calcul ci-dessous.

from tensorflow.keras import layers, Model, utils
import numpy as np

X = np.random.random((10,500,500,3))
y = np.random.random((10,2))

#Model
inp = layers.Input((500,500,3))

x = layers.Conv2D(6, 3, name='conv1')(inp)
x = layers.MaxPooling2D(3)(x)

c2 = layers.Conv2D(9, 3, name='conv2')(x)
c2 = layers.MaxPooling2D(3)(c2)

c3 = layers.Conv2D(12, 3, name='conv3')(c2)
c3 = layers.MaxPooling2D(3)(c3)

x = layers.Conv2D(15, 3, name='conv4')(c3)
x = layers.MaxPooling2D(3)(x)

x = layers.Flatten()(x)
out1 = layers.Dense(2, activation='softmax', name='first')(x)

c = layers.Lambda(lambda x: x[:,:1])(out1)
d = layers.Lambda(lambda x: x[:,1:])(out1)

c = layers.Multiply()([c3, c])
d = layers.Multiply()([c2, d])

c = layers.Conv2D(15, 3, name='conv5')(c)
c = layers.MaxPooling2D(3)(c)
c = layers.Flatten()(c)

d = layers.Conv2D(12, 3, name='conv6')(d)
d = layers.MaxPooling2D(3)(d)
d = layers.Conv2D(15, 3, name='conv7')(d)
d = layers.MaxPooling2D(3)(d)
d = layers.Flatten()(d)

x = layers.concatenate([c,d])
x = layers.Dense(32)(x)
out2 = layers.Dense(2, activation='softmax',name='second')(x)

model = Model(inp, [out1, out2])
model.compile(optimizer='adam', loss='categorical_crossentropy', loss_weights=[0.5, 0.5])

model.fit(X, [y, y], epochs=5)

utils.plot_model(model, show_layer_names=False, show_shapes=True)
Epoch 1/5
1/1 [==============================] - 1s 1s/step - loss: 0.6819 - first_loss: 0.7424 - second_loss: 0.6214
Epoch 2/5
1/1 [==============================] - 0s 423ms/step - loss: 0.6381 - first_loss: 0.6361 - second_loss: 0.6400
Epoch 3/5
1/1 [==============================] - 0s 442ms/step - loss: 0.6137 - first_loss: 0.6126 - second_loss: 0.6147
Epoch 4/5
1/1 [==============================] - 0s 434ms/step - loss: 0.6214 - first_loss: 0.6159 - second_loss: 0.6268
Epoch 5/5
1/1 [==============================] - 0s 427ms/step - loss: 0.6248 - first_loss: 0.6184 - second_loss: 0.6311


Prochain article
No