Problem
Visualizing the style-feature encoding of the MNIST test dataset, using an encoder model trained with the train-aae
script, gives the following result:
Generating images with the decoder model, by sampling random style vectors from a normal distribution (loc=0, scale=1), gives:
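The sampling step can be sketched as follows. This is a minimal sketch: `decoder` and `style_dim` are hypothetical stand-ins for the trained decoder model and its latent dimensionality, which are not specified in the original text.

```python
import numpy as np

rng = np.random.default_rng(42)
style_dim = 8  # assumed latent size; use the size your encoder was trained with

# Sample a style vector from the same prior used during training: N(loc=0, scale=1).
z = rng.normal(loc=0.0, scale=1.0, size=(1, style_dim))

# `decoder` is hypothetical -- it stands in for the trained decoder model,
# which maps a style vector to a 28x28 MNIST-like image:
# image = decoder.predict(z)
```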
Note that some digits are not reconstructed well. This happens because, even though the overall distribution of each feature component is nicely centered around zero, the per-digit distributions are still skewed for some digits. This is likely because the discriminator is digit-agnostic, and thus cannot enforce the prior distribution on a per-digit basis.
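The per-digit skew can be checked directly by comparing the overall mean of the encodings with the mean computed separately for each digit. A minimal sketch, assuming `encodings` comes from `encoder.predict(x_test)` and `labels` are the MNIST test labels (both replaced by random placeholders here):

```python
import numpy as np

rng = np.random.default_rng(0)
n, style_dim = 1000, 8

# Placeholders: in practice, encodings = encoder.predict(x_test)
# and labels = the MNIST test labels.
encodings = rng.normal(size=(n, style_dim))
labels = rng.integers(0, 10, size=n)

# Overall mean per feature component -- close to zero if the aggregate prior holds.
overall_mean = encodings.mean(axis=0)

# Mean per feature component, computed separately for each digit class.
# A per-digit mean far from zero indicates the skew described above.
per_digit_mean = np.array([encodings[labels == d].mean(axis=0) for d in range(10)])
```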
Solution
Make the discriminator digit-aware (essentially a simpler variant of the idea from section 2.3 of the paper). When training the discriminator:
- "Fake" inputs should be the one-hot representation of the label + the Encoder's style encoding output.
- "Real" inputs should be a random one-hot representation + a prior-distribution random-sampled style vector.