Inference Variationnelle et VAE

Ce post introduit de zéro les auto-encodeurs variationnels et donc les bases de l'inférence variationnelle. Nous démarrons avec le concept d'estimation de modèles probabilistes et les notions de paramètres et de variables latentes.

# Estimation de modèles probabilistes

On cherche les paramètres que l'on appelle (un ensemble de variables) et les variables latentes que l'on appelle d'un modèle probabiliste. En gros, ici on dira que les paramètres sont des choses (inconnues) globales et les variables latentes des choses (inconnues) qui dépendent des points.

EX: dans le cas très simple où l'on fait la supposition que nous avons des points tirés d'une gaussienne (loi normale), les paramètres sont la moyenne et variance/covariance de la gaussienne et est vide.

On représente généralement cela de la façon suivante (où le rectangle représente la répétition pour chaque point du dataset, et l'absence de flèche représente l'indépendance conditionnelle entre variables) :

EX: dans un GMM (mélange de gaussiennes) sont les moyennes, variances (et poids des gaussiennes, mais on va considérer des poids égaux pour simplifier les explications), alors que est l'affectation du i-ème point à une gaussienne donnée

Dans certaines situations, ou à certains moments d'un raisonnement, le dataset est figé donc en fait il n'y a souvent aucune différence entre les deux concepts. On parle alors indifféremment de variable ou paramètres, pour

NB: Par contre, dans un VAE, on va faire de l'amortissement (cf bien plus bas), donc on traitera les différemment et qui plus est, on ignorera souvent les qui seront parfois cachés dans les dérivations.

La tâche, en général, est de calculer ce que nos observations nous disent sur nos variables inconnues et , plus précisément nous voulons estimer la probabilité de nos variables sachant (étant données) les données que nous avons observées, c'est à dire .

Dans pas mal de cas, on s'intéresse vraiment à cette distribution sur les paramètres, mais parfois on s'intéresse à trouver juste une valeur pour ces paramètres, comme pour le MAP (maximum a posteriori) ou le MLE (maximum likelihood estimator).

Si on fait du MAP (maximum a posteriori), on cherche juste le max (de cette probabilité inconnue), c'est à dire . On a aussi le MLE (maximum likelihood estimator) qui, lui, considère la vraisemblance (ceci est un peu bizarre si on prend le temps d'y réfléchir... conceptuellement il parait étrange de maximiser une probabilité sur sa condition et non sur le domaine de cette distribution). On peut lier les deux par la règle des probabilités conditionnelles (Bayes) : . Comme on optimise sur et , est en fait constant et peut être ignoré d'une certaine façon. Du coup seul (qui est un a priori sur nos variables inconnues) s'ajoute pour le MAP, donc en gros, MAP est comme MLE mais avec un a priori... ou vu autrement MLE est un MAP avec un prior non informatif (uniforme, si cela existe).

On peut représenter le modèle avec son a priori comme ça :

# Utilisation d'un modèle probabiliste

Si on a trouvé des valeurs de et sur un ensemble de données, on peut très bien prendre un nouvel ensemble de données, garder et trouver les meilleurs (qui dépendent des points) qui correspondent à ces nouveaux points, en travaillant sur . On peut aussi garder nos anciens points et tout refaire avec l'union des deux ensembles de données. Ou encore, on peut résumer tous les anciens points dans un nouveau prior sur (voir sur selon sa forme, ce que l'on nomme parfois « posterior predictive »).

EX: pour un GMM, avec un MAP (ou MLE), on a trouvé notre mélange de gaussienne (on avait aussi qui nous disait à quelle gaussiene chaque point appartient). On peut alors, pour un nouveau point , chercher la distribution sur correspondante (pas juste la meilleure valeur) qui nous donne à quel point ce nouveau point appartient à chaque gaussienne (on peut quand même prendre le si on veux l'affecter à une composante/cluster).

# Inférence Variationnelle de base

Dans l'inférence variationnelle, on considère toujours , mais on ne veut pas juste trouver le MAP, on veut vraiment la distribution, mais le problème est souvent compliqué. Du coup, on va chercher les paramètres d'une distribution qui visera à approximer (inconnu, induit par nos données ), et on notera cette distribution approximante ( sont les paramètres qui contrôlent , et sont parfois omis dans la notation).

NB: on suppose souvent une sorte d'indépendance dans la distribution approximante, et donc la distribution se décompose en produit de deux termes notés abusément :

On peut informellement représenter cela comme ça :

Le but de l'inférence variationnelle est alors de minimiser la « distance » entre la vraie (inconnue) et notre meilleure approximation (et donc de trouver les meilleurs paramètres ). La formulation est de dire que l'on va minimiser la KL divergence (Kullback–Leibler divergence) : ou, sans raccourci de notation,

EX: pour un GMM, on a beaucoup de paramètres (les moyennes et variances de nos composantes) et encore plus de (l'affectation des points). doit être une distribution sur ces paramètres, donc une distribution sur les moyennes et les variances et une distribution sur les affectations. Pour les variables continues (comme la moyenne et variance) on utilise souvent une gaussienne (donc une moyenne et une variance pour chaque paramètre) et pour les variables discrètes (affectation ) on utilise une loi catégorique (pour chaque point, on a donc poids s'il y a composantes)... ce qui est énorme s'il y a beaucoup de points.

Dans ce mode d'inférence, on a très souvent plus de paramètres que de variables dans . Cela dit, une estimée ponctuelle (un seul jeu de valeur) pour le vecteur correspond à une distribution sur ... donc « on y gagne », mais ça peut rester un gros problème surtout si est grand (grand nombre de points).

On peut faire quelques développements pour ré-écrire la fonction que l'on optimise :

(minimiser)

est la « neg-entropie » (opposé de l'entropie) de la distribution .

On peut retourner l'autre terme :

Et développer le log :

Comme le dataset est fixe, est une constante que l'on pourra ignorer. Le premier terme est la log-vraisemblance et le second un log-prior sur les paramètres. En récapitulant, on trouvera que la optimisée en inférence variation est faite de ces deux termes (moyennés sur les valeurs de venant de ) et de l'entropie de :

(à la louche : on veut maximiser l'entropie (rendre la distribution plutôt uniforme), maximiser l'accord entre et le prior, et maximiser la log-vraisemblance)

EX: pour une distribution normale simple de moyenne et d'écart type , l'entropie est (cf wikipedia)

On peut aussi re-regrouper les premiers termes (on aurait pu le faire dès le départ) :

qui donne aussi

qui s'interprète assez bien en tant que : minimiser la KL entre et le prior, tout en maximisant la log-vraisemblance (en moyenne sur ).

EX: pour une distribution qui est et un prior , on a une KL qui vaut (cf wikipedia, qui se démontre en développement la formule) ... que l'on retrouve dans les VAE typiques.

# Quelle liberté avons-nous en restant en mode « Variational Inference » ?

En restant dans le formalisme et sans rien changer, il est possible de choisir la forme de et aussi le prior .

On peut, plus difficilement je dirais, se dire que l'on va optimiser autre chose que (qui a ses défauts) (par exemple en mode wasserstein ? , donne Wasserstein Variational Inference ?) mais il faut arriver à faire des dérivations similaires (ce qui n'est pas du tout du tout gagné).

# Que sont donc les VAE ?

# Préliminaires (ou rappels) : Auto-Encodeur non-variationnel

Un auto-encodeur (AE) est un modèle (par ex, un réseau de neurones) symétrique, dans le sens où il a autant de neurone sur la couche de sortie que sur la couche d'entrée. Un AE peut être imaginé comme fait de deux parties : une première moitié qu'on appelle « encodeur » et une deuxième moitié qu'on appelle « décodeur ». On utilise généralement les AE en apprentissage non-supervisé : on leur donne le dataset d'apprentissage sans étiquette, en passant la même valeur à la fois en entrée et en tant que sortie attendue. Lors de l'entraînement, qui vise à apprendre à la fois l'encodeur et le décodeur, on minimise généralement l'erreur de reconstruction, c'est à dire le carré de la distance euclidienne entre la sortie de l'auto-encodeur et la sortie attendue (qui est aussi son entrée).

wikipedia

Quel intérêt d'apprendre à un réseau de neurones à re-générer des données que l'on a déjà ? Tout dépend de la forme du réseau : si, comme dans la plupart des AE, il a une forme d'entonnoir, avec une couche centrale de taille beaucoup plus petite que celle des entrées (et donc aussi des sorties), alors la première partie du réseau (encoder) va apprendre à projeter les entrées dans un espace latent (la couche du milieu) plus petit, donc à "compresser" les données (i.e. à trouver les éléments importants à conserver et ceux, moins important à jeter) et la deuxième partie (decoder) va apprendre à reconstruire le mieux possible les données originales à partir de leur représentation latente/compressée. En général, on va choisir le nombre de neurones sur la couche du milieu plus petit que la taille des couches d'entrée/sortie, pour compresser les données. Cependant, on pourrait aussi choisir une taille de la couche latente/du milieu plus grande que les couches d'entrée/sortie, ce qui forcerait l'AE à chercher une projection dans un espace de plus grande dimension. Dans certains cas, cela peut être utile, par exemple pour chercher une séparation linéaire de ces données dans ce plus grand espace (un peu comme le font les SVM). Dans de tels cas, il faut faire attention 1) à ce que l'AE n'apprenne pas la fonction identité (en fixant ses poids à 1 partout = recopier les entrées sur les sorties) et 2) à ce que l'AE n'apprenne pas par cœur les données (puisqu'il a "plus" de place dans son espace latent que dans l'espace s'entrée). Ce dernier cas peut être résolu en ajoutant de la régularisation (drop out ou terme additionnel dans la fonction de loss).

# Décodeur

Par rapport à l'inférence variationnelle, la partie la plus directe d'un VAE est le décodeur (qui va traiter indépendamment chaque point). Étant donnée une représentation latente (quelconque) d'un seul point , le décodeur (combiné à la loss) est responsable d'évaluer . Le fait que le décodeur ne traite qu'un seul point revient à dire que la vraisemblance de (sachant et ) se décompose (indépendance conditionnelle), autrement dit , ce qui est la supposition de plein de modèles à variables latentes (type GMM).

EX: dans un VAE, quand on utilise la distance euclidienne au carré entre le point et le point reconstruit (, aussi appelée erreur de reconstruction), en fait, ce que l'on dit est que le de décodeur prédit la moyenne d'une loi normale, que l'on peut formaliser comme (la loi centrée en que l'on évalue en ). Comme on minimise une log proba (cf l'inférence variationnelle), on retombe sur la norme 2 au carré () avec (à une constante additive près). On a pris une variance de 1 ci-dessus, mais avec les calculs on retrouve que cette variance contrôle au final le tradeoff entre reconstruction et régularisation.

NB: si on utilise une sigmoid en fin de décodeur et une « cross entropy loss », en fait, cela revient à supposer que le décodeur produit le/les paramètres d'une loi de bernoulli (et on optimise alors toujours bien la log-proba .

NB: on pourrait ici explicitement avoir d'autres distributions, plus heavy-tailed, comme une laplace (qui correspond à une pénalité en norme )

En résumé, le décodeur (+ la loss du réseau, que l'on peut choisir) est un modèle de la vraisemblance, paramétré par (l'architecture et) les poids d'un réseau de neurones. Les poids (paramètres) du réseau cachent donc ce que l'on appelait précédemment (i.e. ).

# Encodeur

Qu'est ce que donc l'encodeur dans tout ça ? Nous allons voir ce qu'est l'amortissement (amortization) en inférence variationnelle.

En fait, l'encodeur est quelque chose de totalement différent. Dans l'inférence variationnelle, un problème récurrent est celui du nombre de paramètres qui dépend du nombre de points observés.

EX: dans un GMM à 10 composantes, si on a 1M de points, alors contient 1M de valeurs (l'affectation de chacun des 1M de points à une composante) et si on veut mettre une distribution sur ces , alors elle aura par exemple 1M × 9 paramètres (9 paramètres pour définir les proba de tirage d'un « dé » à 10 faces).

L'amortissement en inférence variationnelle vise à réduire ce nombre de paramètres et surtout à faire qu'il ne dépende pas du nombre de points. Il permet de procurer une régularité et une capacité d'extrapolation en se basant sur l'espace des observations (). En inférence variationnelle classique, pour chaque point , un jeu de paramètres sert à représenter la distribution de . L'amortissement consiste à apprendre une unique fonction qui, à partir de n'importe quel point arrive à prédire une distribution sur . Les paramètres (qui dépendaient du nombre de points) sont remplacés par un jeu de paramètres unique qui défini la fonction qui associe une distribution sur l'espace latent pour n'importe quel point d'entrée.

L'encodeur d'un VAE n'est autre que la fonction issue de l'amortissement. On a plus formellement, pour n'importe quel (ou même n'importe quel point de l'espace d'entrée, même s'il n'est pas dans le dataset) .

NB: très souvent dans un VAE, la distribution variationnelle est choisie comme étant une loi normale, nous avons alors un encodeur qui prédit une moyenne et une variance (ou covariance), qui définit la distribution sur l'espace latent

# Inférence doublement stochastique

Un autre article pourrait être dédié à ces points mais voici une introduction à l'apprentissage des VAE.

L'apprentissage d'un VAE est très souvent doublement stochastique.

Un premier niveau stochastique est le fait de traiter un sous-ensemble (minibatch) des points à chaque fois (généralisation du SGD aux minibatchs). On peut considérer un minibatch de taille 1 pour le raisonnement.

Un second niveau stochastique est d'approximer l'espérance sur à l'aide d'un unique tirage, on remplace donc par un approximateur de l'espérance qui utilise un seul tirage et donne donc avec

# Reparametrization trick

Comme l'étape de sampling est non-différentiable, on utilise le fait que tirer de est équivalent à tirer de puis à multiplier par et ajouter , qui fait que le sampling est « isolé » et que le gradient peut bien passer. (à détailler au besoin, mais c'est un trick pour l'implémentation, relativement moins important)