How does OpenAI's DALL-E work?

Introduction

Before I say anything else, create an account and take DALL-E 2 out for a spin. It is an example of generative AI, a field that saw important, exciting, and viral breakthroughs during 2022. Here are some examples from the DALL-E paper:

Examples of captions and generated images from the DALL-E paper.

To understand what generative AI is, let’s say you have a set of images, each labeled according to whether or not it contains a cat. A (probabilistic) discriminative model tries to extract meaningful information from images (such as the presence of certain types of edges or textures) and use it to discriminate between cat and non-cat images. Most classification models implemented across industry are of this type. A generative model, on the other hand, learns how to generate new (image, cat/no cat label) pairs. Ian Goodfellow, a pioneer of generative modeling, very succinctly explains why generative modeling is usually much harder to do well:

Can you look at a painting and recognize it as being the Mona Lisa? You probably can. That’s discriminative modeling. Can you paint the Mona Lisa yourself? You probably can’t. That’s generative modeling.

The most popular examples of generative AI that have emerged over the past year have been in the fields of computer vision (Stable Diffusion, DALL-E) and natural language processing (ChatGPT). In this post, I want to provide a very high-level overview of how DALL-E learns to generate images from text prompts.

What we will not discuss (but should at least mention)

The “machine learning” that powers many of these large, impressive models is very often not the hardest part of developing them. There are computational considerations and optimizations that are critically important to making such models actually work, but I don’t think describing them here will add a lot of value for my readers, so we won’t be going into those details. (In fact, in the DALL-E paper, they explicitly acknowledge that “[g]etting the model to train in 16-bit precision past one billion parameters, without diverging was the most challenging part of this project.”)

Another important thing to remember is that the amazing capabilities of these models emerge from mountains of high-quality data and computing power. Even if the model code is made available to the public, training and using the models without the ability to handle huge quantities of data and throw massive amounts of compute at the problem will result in useless models at best. There are some efforts to create smaller versions of these models that can be run and used by individuals, but so far they don’t seem as capable, in general, as their huge progenitors.

Disclaimer

This post is my high-level overview/summary of part of the DALL-E paper. It is possible that I misunderstood something or explained it incorrectly. If you come across any such mistakes, let me know so that I can learn from and correct them.

With that out of the way, let’s have a look under the hood.

High-level overview

DALL-E learns using about 250 million (image, text) pairs. The learning process is roughly:

  1. Learn how to come up with useful (and compressed) representations of images
  2. Turn each prompt into a sequence of tokens
  3. Combine the representations of the images and the corresponding prompts
  4. Train a neural network to predict the next token of the combined representation given some of the previous tokens.

We will discuss each of these steps in turn.

Learning image representations

One of the most important, foundational ideas in deep learning is: neural networks like numbers (not images or text). In order to bring the full power of neural networks to bear on text and vision problems, we usually use approaches that are variations on this theme:

  1. Convert the image (or text) into a numerical representation such that the meaning/content of the image (or text) is preserved. (For example, in representation space, collections of numbers corresponding to images of cats should be closer to one another than they are to collections of numbers corresponding to images of galaxies.) We will refer to these as representations or embeddings.
  2. Use the representations for some task that we care about (e.g. discriminate between cat and non-cat images).
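
To make “closer in representation space” concrete, here is a tiny Python sketch. The three embeddings are numbers I made up for illustration (a real model would produce much longer vectors), and I am assuming plain Euclidean distance as the notion of closeness.

```python
import math

# Hypothetical 3-number embeddings, invented for this illustration only.
cat_photo_1 = [0.9, 0.1, 0.2]
cat_photo_2 = [0.8, 0.2, 0.1]
galaxy_photo = [0.1, 0.9, 0.8]

# Euclidean distance between embeddings: smaller means "more similar".
cat_to_cat = math.dist(cat_photo_1, cat_photo_2)
cat_to_galaxy = math.dist(cat_photo_1, galaxy_photo)

print(f"cat vs. cat:    {cat_to_cat:.2f}")
print(f"cat vs. galaxy: {cat_to_galaxy:.2f}")
```

A good representation is one where the first distance is reliably smaller than the second.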

There are many approaches that accomplish step 1. In the case of DALL-E, OpenAI used an autoencoder. This architecture consists of two parts:

  1. Encoder: takes an image and produces a representation
  2. Decoder: takes the representation and tries to reproduce the image

In order to learn, the encoder and decoder use penalties incurred for differences between the original image and the decoder’s reconstruction to update their internal states. When training is complete, we can use the encoder to produce useful image representations for whatever other tasks we intend to carry out with the images as input. Intuitively, we can think about an autoencoder as a kind of compression algorithm. We take an image, compress it into a representation that is (1) smaller than the original image, but (2) such that we can (mostly) reconstruct the original image. If we can do this well, then we’ve come up with a representation that seems to carry much of the important information from the original image, which is what we wanted.
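
Here is a minimal sketch of that training loop, written in plain Python so that nothing is hidden inside a library. It is not the model from the paper: I am assuming a toy linear encoder and decoder, made-up “images” of only 4 numbers each, a squared-error penalty, and a hand-picked learning rate. The point is only to show the encoder and decoder being nudged by the reconstruction penalty.

```python
import random

random.seed(0)

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def make_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

def make_image():
    # Toy "image": 4 numbers that only vary along 2 underlying directions,
    # so a 2-number code can in principle reconstruct it.
    a, b = random.uniform(-1, 1), random.uniform(-1, 1)
    return [a, b, a + b, a - b]

images = [make_image() for _ in range(200)]
encoder = make_matrix(2, 4)  # image (4 numbers) -> representation (2 numbers)
decoder = make_matrix(4, 2)  # representation (2 numbers) -> reconstruction (4 numbers)

def reconstruction_loss(image):
    reconstruction = matvec(decoder, matvec(encoder, image))
    return sum((r - x) ** 2 for r, x in zip(reconstruction, image))

def mean_loss():
    return sum(reconstruction_loss(image) for image in images) / len(images)

def train_step(image, learning_rate=0.01):
    code = matvec(encoder, image)
    reconstruction = matvec(decoder, code)
    # Penalty gradient with respect to the reconstruction.
    error = [2 * (r - x) for r, x in zip(reconstruction, image)]
    # Gradient flowing back into the code (computed before the decoder changes).
    code_grad = [sum(error[i] * decoder[i][j] for i in range(4)) for j in range(2)]
    for i in range(4):
        for j in range(2):
            decoder[i][j] -= learning_rate * error[i] * code[j]
    for j in range(2):
        for k in range(4):
            encoder[j][k] -= learning_rate * code_grad[j] * image[k]

loss_before = mean_loss()
for _ in range(50):          # 50 passes over the toy dataset
    for image in images:
        train_step(image)
loss_after = mean_loss()

print(f"mean reconstruction loss before: {loss_before:.4f}, after: {loss_after:.4f}")
```

Once training is done, `matvec(encoder, image)` is the compressed representation we would hand to a downstream task.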

(Really, it uses the generative and more involved cousin of autoencoders called a variational autoencoder, but the high-level idea of learning a compressed representation is the same.)

The autoencoder that DALL-E uses compresses 256x256 (dimensions in pixels) images into 32x32 representations. Each 256x256 image has 3 numbers associated with each pixel (red, green, and blue intensities). Thus, each original image requires 256 * 256 * 3 = 196,608 numbers to represent it, whereas the representation only requires 32 * 32 = 1024 numbers. This is a compression factor of 192!
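
If you want to check the arithmetic above yourself, it fits in a few lines of Python. The variable names are mine; note that this counts numbers, not bits, so it says nothing about how many bits each number needs.

```python
image_side = 256   # pixels per side of the original image
channels = 3       # red, green, blue
grid_side = 32     # slots per side of the compressed representation

numbers_in_image = image_side * image_side * channels
numbers_in_representation = grid_side * grid_side
compression_factor = numbers_in_image / numbers_in_representation

print(numbers_in_image, numbers_in_representation, compression_factor)
```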

For a visual idea of how good these encodings are once the system is trained, here are some examples of original and reconstructed images:

Original (top) and reconstructed (bottom) images produced by the image representation learning system.

(Note: Representations of this kind are often continuous, meaning the numbers in each representation slot can be any real number. In this case, the encoding is discrete, which just means that the numbers in each representation slot are actually whole numbers instead.)
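
To show what a discrete representation looks like, here is a sketch under some assumptions of my own: a tiny 4x4 grid instead of 32x32, a made-up codebook of 8 possible values per slot, and random numbers standing in for the scores a trained encoder would produce. Picking the highest-scoring codebook entry in each slot turns the continuous scores into whole numbers.

```python
import random

random.seed(0)

GRID_SIDE = 4       # assumption: tiny grid for readability (DALL-E uses 32x32)
CODEBOOK_SIZE = 8   # assumption: made-up number of allowed values per slot

# Stand-in for the encoder output: one real-valued score per codebook entry, per slot.
scores = [
    [[random.gauss(0, 1) for _ in range(CODEBOOK_SIZE)] for _ in range(GRID_SIDE)]
    for _ in range(GRID_SIDE)
]

# Discretize: each slot keeps only the index of its highest-scoring codebook entry.
image_tokens = [
    [max(range(CODEBOOK_SIZE), key=slot_scores.__getitem__) for slot_scores in row]
    for row in scores
]

for row in image_tokens:
    print(row)
```

Each printed row is a list of whole numbers between 0 and 7; these are the “image tokens” that the later steps work with.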

Encoding prompts

To encode the prompts, DALL-E uses byte-pair encoding, which can more broadly be categorized as a tokenization method. Tokenization is a way of breaking down unstructured natural language into a finite set of meaningful atomic units. It can be carried out at the level of sentences, words, or even parts of words (e.g., “eating” might be broken into “eat” and “ing”). Each token in a limited vocabulary (e.g., 10k frequent words or subwords) is usually assigned an identifier, which in this case is a number (there are usually also numbers reserved for unknown tokens, beginnings and ends of sentences, etc.). Once we’ve decided how we’re going to tokenize and have a vocabulary of identifiers, we can then process our input text and turn it into its corresponding tokenized representation.

(We can manually decide the level at which to tokenize text (e.g., words, sentences) or we can use machine learning to learn a good tokenization, but we aren’t going to go into detail about those approaches here.)
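
For the curious, here is a toy sketch of the core idea behind byte-pair encoding: start from single characters and repeatedly merge the most frequent adjacent pair. This is a simplified, character-level version on a made-up word list, not the tokenizer OpenAI used; the function names and the training words are my own.

```python
from collections import Counter

def merge_pair(symbols, pair):
    """Replace every adjacent occurrence of `pair` in `symbols` with one fused symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return tuple(merged)

def learn_bpe_merges(words, num_merges):
    """Learn merge rules by repeatedly fusing the most frequent adjacent pair."""
    corpus = Counter(tuple(word) for word in words)  # word as symbols -> count
    merges = []
    for _ in range(num_merges):
        pair_counts = Counter()
        for symbols, count in corpus.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += count
        if not pair_counts:
            break
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        new_corpus = Counter()
        for symbols, count in corpus.items():
            new_corpus[merge_pair(symbols, best)] += count
        corpus = new_corpus
    return merges

def tokenize(word, merges):
    """Apply the learned merges, in order, to a new word."""
    symbols = tuple(word)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return list(symbols)

# Made-up training text, chosen so that "eat" and "ing" are frequent.
training_words = ["eating", "eating", "sleeping", "reading", "eat", "eat", "eats"]
merges = learn_bpe_merges(training_words, num_merges=4)
tokens = tokenize("eating", merges)

# Assign each vocabulary entry a numeric identifier.
vocabulary = sorted({ch for word in training_words for ch in word} | {a + b for a, b in merges})
token_to_id = {token: index for index, token in enumerate(vocabulary)}
token_ids = [token_to_id[token] for token in tokens]

print(tokens, token_ids)
```

With this toy corpus and four merges, “eating” comes out as “eat” followed by “ing”, matching the example above. A real tokenizer learns thousands of merges from a huge corpus and reserves extra identifiers for special tokens.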

Jointly modeling text and image tokens

Once we have computed our image representations and tokenized our text, we glue the representations together into what is essentially a composite representation for a (text, image) pair. We then train a model called a transformer whose input is the entire stream of concatenated text and image tokens. The model learns to predict the next token in the sequence using earlier tokens.
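
Here is a sketch of what “gluing together” and “predict the next token” might look like as data. The vocabulary sizes and token values are made up, and shifting the image token ids so they cannot collide with text token ids is my own assumption about one reasonable way to share a single vocabulary.

```python
TEXT_VOCAB_SIZE = 16   # hypothetical number of distinct text tokens
IMAGE_VOCAB_SIZE = 8   # hypothetical number of distinct image tokens

def build_sequence(text_ids, image_ids):
    """Concatenate text tokens and image tokens into one stream.

    Image ids are offset by TEXT_VOCAB_SIZE so that, for example, text token 3
    and image token 3 remain distinguishable (an assumption of this sketch).
    """
    return list(text_ids) + [TEXT_VOCAB_SIZE + image_id for image_id in image_ids]

def next_token_examples(sequence):
    """Turn one sequence into (earlier tokens, next token) training pairs."""
    return [(sequence[:i], sequence[i]) for i in range(1, len(sequence))]

sequence = build_sequence(text_ids=[3, 7], image_ids=[0, 5])
examples = next_token_examples(sequence)

for context, target in examples:
    print(context, "->", target)
```

Every position in the stream yields one training example, so a single (text, image) pair teaches the model many small prediction tasks at once.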

Transformers are based on a mechanism called (self-)attention. In our case, this means that the model learns to give different weight to different previously generated tokens as it attempts to predict the next one. For example, if we were trying to predict the next word (“cold”) in the sentence “I need a jacket because it is ___”, the model would learn that it should give the word “jacket” more weight than the word “I.”

(DALL-E actually uses multi-head attention, which means it learns many weighting schemes at the same time for a given set of previously generated tokens and then makes a prediction using a combination of the outputs of all of those schemes.)
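
To make the “different weight to different previous tokens” idea concrete, here is a stripped-down, single-head sketch of causal self-attention in plain Python. It is a simplification: I use the token vectors directly as queries, keys, and values (a real transformer learns separate projections for each), and the three input vectors are made up.

```python
import math

def softmax(scores):
    """Turn arbitrary scores into positive weights that sum to 1."""
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def causal_self_attention(queries, keys, values):
    """For each position, mix the values of that position and all earlier ones."""
    dim = len(keys[0])
    outputs, all_weights = [], []
    for t, query in enumerate(queries):
        # Only positions 0..t are visible: the model may not peek at later tokens.
        scores = [dot(query, keys[s]) / math.sqrt(dim) for s in range(t + 1)]
        weights = softmax(scores)
        output = [
            sum(w * values[s][d] for s, w in enumerate(weights))
            for d in range(len(values[0]))
        ]
        outputs.append(output)
        all_weights.append(weights)
    return outputs, all_weights

# Made-up 2-number vectors standing in for three tokens.
token_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = causal_self_attention(token_vectors, token_vectors, token_vectors)

for position, row in enumerate(weights):
    print(position, [round(w, 2) for w in row])
```

Each printed row shows how much weight one position gives to itself and to each earlier position. Multi-head attention would run several copies of this with different learned projections and combine their outputs.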

How to generate images

Once we’ve trained the transformer from the previous section, given a new prompt that we haven’t seen, we can:

  1. Tokenize it using the same byte-pair encoding we used for the training prompts
  2. Generate (hopefully reasonable) image tokens one-by-one using our transformer. The first image token would be generated using all of the text tokens, the second image token would use all text tokens and the first image token, and so on. (Note: This is not exactly how it works, but it’s close enough for our purposes.)
  3. Use the decoder that we trained in the first step to translate those image tokens back into an image
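
The token-by-token loop in step 2 can be sketched like this. Everything here is a stand-in: `toy_next_token_scores` is a hypothetical placeholder for the trained transformer (it just produces repeatable pseudo-random scores), the sizes are tiny, and I always pick the highest-scoring token, whereas a real system would typically sample.

```python
import random

IMAGE_VOCAB_SIZE = 8    # hypothetical number of distinct image tokens
NUM_IMAGE_TOKENS = 16   # hypothetical 4x4 grid (DALL-E uses 32x32 = 1024)

def toy_next_token_scores(sequence):
    """Hypothetical stand-in for the trained transformer.

    Returns one score per possible image token, based only on the tokens so far.
    """
    rng = random.Random(sum(sequence) + len(sequence))
    return [rng.random() for _ in range(IMAGE_VOCAB_SIZE)]

def generate_image_tokens(text_tokens):
    """Generate image tokens one by one, feeding each back in as context."""
    sequence = list(text_tokens)
    image_tokens = []
    for _ in range(NUM_IMAGE_TOKENS):
        scores = toy_next_token_scores(sequence)
        token = max(range(IMAGE_VOCAB_SIZE), key=scores.__getitem__)  # greedy choice
        sequence.append(token)       # the new token becomes context for the next one
        image_tokens.append(token)
    return image_tokens

prompt_tokens = [3, 7, 2]  # made-up ids for a tokenized prompt
image_tokens = generate_image_tokens(prompt_tokens)

# Arrange the flat list into the grid the decoder would turn back into pixels.
grid = [image_tokens[row * 4:(row + 1) * 4] for row in range(4)]
for row in grid:
    print(row)
```

The resulting grid of whole numbers is what would be handed to the decoder in step 3.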

Conclusion

I glossed over many of the mathematical and computational details of how this works (I don’t even have my head around all of them!), but the goal was to demystify one approach used to build a(n awesome) generative AI system that has taken the internet by storm. I hope you enjoyed it!

This post is licensed under CC BY 4.0 by the author.