An introduction to Flow Matching
This write-up is primarily based on the excellent lecture notes, which I encourage readers to check out. I have tried to condense the content and simplify some derivations. I expect to follow up with a separate write-up on diffusion and SDEs.
We want to model some data-generating distribution $p_{data}$ using samples $(y_1, y_2, …, y_n)$ drawn from it. We will do so by defining a map from samples of some simple distribution $p_{\text{init}}$ (for example, a Gaussian with zero mean and identity covariance matrix) to samples from $p_{data}$. To this end, for each $t \in [0, 1]$ we define an intermediate variable $x_t$ drawn from a distribution $p_t$, where $p_0 = p_{\text{init}}$ and $p_1 = p_{data}$. Note that $x_t$ has the same dimensionality as the data. The trajectory of $x_t$ is defined by the ODE:
\[\frac{d}{dt}x_t = u_t(x_t)\]where we want to learn $u_t$ such that drawing different samples $x_0$ from $p_{\text{init}}$ and following this ODE from $t = 0$ to $t = 1$ will lead to different samples from $p_{data}$.
Note that $u_t$ is time dependent and has the same input and output dimensionality as the data. Thus, we can interpret it as a vector field on the data space that tells us, at each time, how to nudge each point so that points drawn from $p_{\text{init}}$ at $t = 0$ and evolved to $t = 1$ end up as samples from $p_{data}$.
Here $u_t$ is said to describe a flow.
If we have such a $u_t$ we can simulate it to draw samples from $p_{data}$, for example, using the Euler Method with:
\[\begin{aligned} x_0 &\sim p_{\text{init}} \\ x_{t+h} &= x_t + h \cdot u_t(x_t), \quad t \in \{0, h, 2h, \ldots, (n-1)h\} \end{aligned}\]where $n$ is the number of steps and $h=1/n$ is the step size.
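The Euler update above can be sketched in a few lines of NumPy. This is a minimal illustration rather than a reference implementation: the function name `euler_sample` and the toy field $u_t(x) = -x$ are assumptions chosen only because that ODE has the known solution $x_1 = x_0 e^{-1}$, which makes the result easy to check.

```python
import numpy as np

def euler_sample(u, x0, n_steps):
    """Integrate dx/dt = u(t, x) from t = 0 to t = 1 with n_steps Euler steps."""
    h = 1.0 / n_steps                 # step size h = 1 / n
    x = np.array(x0, dtype=float)
    for i in range(n_steps):
        t = i * h                     # t = 0, h, 2h, ..., (n - 1) h
        x = x + h * u(t, x)           # x_{t+h} = x_t + h * u_t(x_t)
    return x

# Hypothetical toy field (not a learned model): dx/dt = -x, so x_1 = x_0 * exp(-1).
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 2))      # four samples from p_init = N(0, I) in 2-D
x1 = euler_sample(lambda t, x: -x, x0, n_steps=1000)
```

Computing `t` as `i * h` instead of accumulating `t += h` avoids floating-point drift over many steps.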
To learn such a $u_t$ we will try to explicitly construct a target $u_t$ which has this desired property.
First, instead of constructing a flow for the whole data distribution, we will construct a flow $u_{t, y}$ for a single data point $y$, such that following this flow from $t = 0$ to $t = 1$ lands on $y$ for any $x_0$ drawn from $p_{\text{init}}$:
\[\begin{aligned} x_0 &\sim p_{\text{init}} \\ \frac{d}{dt}x_t &= u_{t,y}(x_t) \\ x_1 &= y \end{aligned}\]To construct this flow, we first construct a family of conditional interpolating distributions $p(x_t \mid y)$ which describe the distribution of $x_t$ at each $t$. We have $p(x_0 \mid y) = p_{\text{init}}$, and since $p(x_1 \mid y)$ is concentrated on the single point $y$, we can model it as a Dirac delta distribution $\delta(y)$ centered at $y$. It is called a conditional interpolating distribution because, for a given $y$, it interpolates between $p_{\text{init}}$ and the sample $y$. We will use it later to construct a marginal interpolating distribution and the corresponding marginal flow, which interpolate between $p_{\text{init}}$ and $p_{data}$ as required.
We can define an example conditional interpolating distribution using a Gaussian:
\[p(x_t \mid y) = \mathcal{N}(\alpha_t\ y, \beta_t^2\ I)\]where $\alpha_t$ and $\beta_t$ are defined such that
\[\begin{aligned} \alpha_0 = 0 \quad \beta_0 = 1 \\ \alpha_1 = 1 \quad \beta_1 = 0 \end{aligned}\]Verify for yourself that $p_{\text{init}} := p(x_0 \mid y) = \mathcal{N}(0, I)$ and $p(x_1 \mid y) = \delta(y)$.
We can now construct a conditional trajectory $x_t$ that each $x_0 \sim p_{\text{init}}$ takes, such that $x_t$ follows the conditional interpolating distribution $p(x_t \mid y)$, by defining the trajectory as:
\[x_t = \alpha_t\ y + \beta_t\ x_0\]for each $x_0$.
Since $x_0$ itself is normally distributed with zero mean and identity covariance, you can verify for yourself that $x_t$ has mean $\alpha_t y$ and covariance $\beta_t^2 I$, i.e. it follows the Normal distribution outlined above.
We simply have to take the time derivative of this $x_t$ to get the conditional flow:
\[u_{t,y}(x_t) = \frac{d}{dt}x_t = \dot{\alpha_t}y + \dot{\beta_t}x_0 = \dot{\alpha_t}y+\frac{\dot{\beta_t}}{\beta_t}(x_t - \alpha_ty)\]where the last step substitutes $x_0 = (x_t - \alpha_t y)/\beta_t$.
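As a quick sanity check of the Gaussian conditional path, the sketch below samples $x_t = \alpha_t y + \beta_t x_0$ and compares the empirical mean and standard deviation with $\alpha_t y$ and $\beta_t$. The linear schedule $\alpha_t = t$, $\beta_t = 1 - t$ is an assumption made for illustration (the text only fixes the boundary values), and the helper names are hypothetical.

```python
import numpy as np

# Assumed schedule (only the endpoints are fixed by the text): alpha_t = t, beta_t = 1 - t.
def alpha(t): return t
def beta(t): return 1.0 - t
def d_alpha(t): return 1.0            # time derivative of alpha_t
def d_beta(t): return -1.0            # time derivative of beta_t

def conditional_path(t, y, x0):
    """x_t = alpha_t * y + beta_t * x_0."""
    return alpha(t) * y + beta(t) * x0

def conditional_flow(t, y, xt):
    """Conditional flow written as a function of x_t (requires beta_t > 0, i.e. t < 1)."""
    return d_alpha(t) * y + d_beta(t) / beta(t) * (xt - alpha(t) * y)

rng = np.random.default_rng(0)
y = np.array([2.0, -1.0])                      # a single hypothetical data point
x0 = rng.standard_normal((100_000, 2))         # x_0 ~ N(0, I)
t = 0.3
xt = conditional_path(t, y, x0)

emp_mean = xt.mean(axis=0)                     # should be close to alpha_t * y
emp_std = xt.std(axis=0)                       # should be close to beta_t
flow_from_xt = conditional_flow(t, y, xt)      # flow expressed through x_t
flow_from_x0 = d_alpha(t) * y + d_beta(t) * x0 # direct time derivative of the path
```

The two flow expressions agree because $x_0 = (x_t - \alpha_t y)/\beta_t$ along the path.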
Define the marginal interpolating distribution of the random variable $x_t$ as
\[p(x_t) = \int p(x_t|y)p_{data}(y) \, dy\]i.e. to sample $x_t$, we sample $y$ from $p_{data}$, then we sample $x_t$ from $p(x_t \mid y)$. The resulting distribution of $x_t$ is defined as the marginal interpolating distribution.
We can verify that $p(x_0) = p_{\text{init}}$ and $p(x_1) = p_{data}$ using the definition of Dirac-Delta function. Thus, if we can construct a flow to follow this distribution, we can sample from the data distribution by simulating the ODE as outlined in the section Sampling.
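The two-stage sampling procedure for the marginal interpolating distribution can be sketched as follows. The two-point dataset and the schedule $\alpha_t = t$, $\beta_t = 1 - t$ are assumptions for illustration, and `sample_marginal` is a hypothetical helper name.

```python
import numpy as np

def sample_marginal(t, data, n, rng):
    """Draw n samples of x_t: first y ~ p_data (uniform over rows of `data`),
    then x_t ~ p(x_t | y) via x_t = alpha_t * y + beta_t * x_0."""
    y = data[rng.integers(len(data), size=n)]
    x0 = rng.standard_normal(y.shape)
    return t * y + (1.0 - t) * x0      # assumed schedule: alpha_t = t, beta_t = 1 - t

rng = np.random.default_rng(0)
data = np.array([[-2.0], [2.0]])       # hypothetical two-point dataset in 1-D
x_start = sample_marginal(0.0, data, 50_000, rng)   # distributed as p_init
x_end = sample_marginal(1.0, data, 50_000, rng)     # distributed as p_data
```

At $t = 0$ the samples ignore $y$ entirely and look like standard normal noise; at $t = 1$ every sample coincides with one of the data points.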
Important Identity
It turns out that the marginal flow of this distribution can be defined in terms of the conditional flow as:
\[u_t(x_t) = \int u_{t,y}(x_t)\, p(y \mid x_t) \, dy = \mathbb{E}_{y \sim p(y|x_t)}[u_{t, y}(x_t)]\]where $p(y \mid x_t)$ is the posterior distribution of $y$ given $x_t$, which, using Bayes' rule, can be expressed as:
\[p(y|x_t) = \frac{p(x_t|y)\cdot p_{data}(y)}{p(x_t)}\]For intuition, if we think about finitely many $y$'s, the marginal flow at $x_t$ is obtained by averaging the conditional flows toward each $y$, weighted by how likely it is that $x_t$ came from that $y$. I hope this is somewhat intuitive; you can find the proof in the notes linked at the beginning.
For any given $y$, we have seen how to derive the conditional flow we want our model to approximate. For example, in the Gaussian case we derived:
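For a finite dataset the posterior is a softmax over data points, so the marginal flow can be computed exactly. The sketch below does this and then follows the flow with Euler steps; the uniform prior over a two-point dataset, the schedule $\alpha_t = t$, $\beta_t = 1 - t$, and the name `marginal_flow` are all assumptions for illustration.

```python
import numpy as np

def marginal_flow(t, x, data):
    """Posterior-weighted average of conditional flows for a finite dataset.

    x: (n, d) points x_t, data: (K, d) data points y_k with a uniform prior.
    Assumes alpha_t = t, beta_t = 1 - t, and t < 1 so that beta_t > 0.
    """
    a, b = t, 1.0 - t
    diff = x[:, None, :] - a * data[None, :, :]       # x_t - alpha_t * y_k, shape (n, K, d)
    logits = -np.sum(diff ** 2, axis=-1) / (2.0 * b ** 2)  # log p(x_t | y_k) up to a constant
    logits -= logits.max(axis=1, keepdims=True)       # stabilise the softmax
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)                 # posterior p(y_k | x_t), shape (n, K)
    cond = data[None, :, :] - diff / b                # conditional flows, shape (n, K, d)
    return np.einsum("nk,nkd->nd", w, cond)           # average over k with posterior weights

rng = np.random.default_rng(0)
data = np.array([[-2.0], [2.0]])                      # hypothetical dataset
x = rng.standard_normal((1000, 1))                    # x_0 ~ p_init
n_steps = 100
h = 1.0 / n_steps
for i in range(n_steps):
    x = x + h * marginal_flow(i * h, x, data)         # Euler step along the marginal flow

# Distance from each final sample to its nearest data point.
dist = np.linalg.norm(x[:, None, :] - data[None, :, :], axis=-1).min(axis=1)
frac_positive = np.mean(x > 0)
```

Although each conditional flow targets a single point, the posterior-weighted average transports noise to the whole dataset: nearly all final samples sit on a data point, split roughly evenly between the two.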
\[u_{t,y}^{target}(x_t) = \dot{\alpha_t}y+\frac{\dot{\beta_t}}{\beta_t}(x_t - \alpha_ty)\]And the marginal flow we want to actually learn has the form:
\[u_t^{target}(x_t) = \mathbb{E}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t)]\]If we could calculate this, we could define the flow matching loss as:
\[\begin{aligned} \mathcal{L}(\theta) &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t)} \left\| u_t^{\text{target}}(x_t) - u_t^\theta(x_t) \right\|^2 \\ &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t)} \left\| \mathbb{E}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t)] - u_t^\theta(x_t) \right\|^2 \end{aligned}\]That is, for each time $t$ (drawn uniformly from $[0,1]$) and each point $x_t$, we match the model's flow at $x_t$ against the target flow, which is the average of the conditional flows over all the $y$ that $x_t$ could have come from. Since we cannot easily sample from the posterior to estimate the inner expectation, we will apply the following trick to rewrite the loss as:
\[\begin{aligned} \mathcal{L}(\theta) &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t)} \left\| \mathbb{E}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t)] - u_t^\theta(x_t) \right\|^2 \\ &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t)} \left\| \mathbb{E}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t)] - \mathbb{E}_{y \sim p(y|x_t)}[u_t^\theta(x_t)] \right\|^2 \end{aligned}\]where, since the second term does not depend on $y$, rewriting it as expectation over $y$ makes no difference. Now the loss becomes,
\[\begin{aligned} \mathcal{L}(\theta) &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t)} \left\| \mathbb{E}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t) - u_t^\theta(x_t)] \right\|^2 \\ &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t)}\left[ \mathbb{E}_{y \sim p(y|x_t)}\left\|u_{t, y}^{target}(x_t) - u_t^\theta(x_t) \right\|^2 - \operatorname{Var}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t) - u_t^\theta(x_t)]\right]\\ \end{aligned}\]Here we used the identity $\operatorname{Var}[x] = \mathbb{E}[x^2]-\mathbb{E}[x]^2$, applied to each coordinate of the vector and summed over coordinates. Accordingly, $\operatorname{Var}[\cdot]$ of a vector-valued expression denotes the sum of its coordinate-wise variances (the trace of its covariance matrix), which is a scalar. Furthermore, since the second term inside the variance, $u_{t}^{\theta}(x_t)$, does not depend on $y$, it works out as:
\[\operatorname{Var}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t) - u_t^\theta(x_t)] = \operatorname{Var}_{y \sim p(y|x_t)}[u_{t, y}^{target}(x_t)]\]which does not depend on the parameters $\theta$. Thus, we can omit this term and write the new loss with the squared term now inside the expectation as:
\[\begin{aligned} \mathcal{L}(\theta) &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t)}\left[ \mathbb{E}_{y \sim p(y|x_t)} \left\| u_{t, y}^{target}(x_t) - u_t^\theta(x_t) \right\|^2 \right]\\ &= \mathbb{E}_{t \sim [0,1],\ x_t \sim p(x_t),\ y \sim p(y|x_t)} \left\| u_{t, y}^{target}(x_t) - u_t^\theta(x_t) \right\|^2 \\ &= \mathbb{E}_{t \sim [0,1],\ y \sim p_{data}(y),\ x_t \sim p(x_t|y)} \left\| u_{t, y}^{target}(x_t) - u_t^\theta(x_t) \right\|^2 \\ \end{aligned}\]where, for the final step, we rewrite the expectation over the joint distribution $p(x_t, y)$ as one over $y \sim p_{data}(y)$ followed by $x_t \sim p(x_t \mid y)$.
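The claim that the dropped variance term does not depend on the model can be checked numerically at a single fixed $(t, x_t)$. In this sketch the posterior weights `w`, the conditional flows `u`, and the candidate model outputs `v` are random stand-ins (assumptions, not derived from any dataset); the point is only that the gap between the two losses is the same for every `v`.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.random(5)
w /= w.sum()                           # stand-in posterior weights p(y_k | x_t), K = 5
u = rng.standard_normal((5, 3))        # stand-in conditional flows u_{t, y_k}(x_t) in R^3

def marginal_loss(v):
    """|| E_y[u_{t,y}] - v ||^2 : squared norm outside the expectation."""
    return np.sum((w @ u - v) ** 2)

def conditional_loss(v):
    """E_y || u_{t,y} - v ||^2 : squared norm inside the expectation."""
    return np.sum(w * np.sum((u - v) ** 2, axis=1))

# Sum of coordinate-wise variances of the conditional flow under the posterior.
total_var = np.sum(w @ u ** 2 - (w @ u) ** 2)

# The gap between the two losses for several arbitrary model outputs v.
gaps = [conditional_loss(v) - marginal_loss(v) for v in rng.standard_normal((4, 3))]
```

Every entry of `gaps` equals `total_var`, so both losses have the same gradient with respect to the model output and therefore the same minimizer.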
Notice that in the final form we don’t need to sample from the posterior distribution. Thus, we can train our model with minibatch gradient descent. We summarize the training and sampling algorithms in the following section.
Algorithm: Flow Matching Training
Given: Data distribution $p_{data}$, model $u_t^\theta$, and $p(x_t \mid y)$
For each mini-batch of data points $y$ drawn from $p_{data}$ do:
1. Sample $t$ uniformly from $\left[0,1\right]$
2. Sample $x_t \sim p(x_t \mid y)$
3. Compute loss $\mathcal{L}(\theta) = \lVert u_{t, y}^{target}(x_t) - u_t^\theta(x_t) \rVert^2$, averaged over the mini-batch
4. Update model parameters $\theta$ using gradient descent on loss $\mathcal{L}(\theta)$
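The training steps above can be sketched with NumPy alone. To keep the gradient explicit, the model here is a toy that is linear in a hand-picked feature map; that feature map, the two-point dataset, the schedule $\alpha_t = t$, $\beta_t = 1 - t$, and all hyperparameters are assumptions for illustration. In practice $u_t^\theta$ would typically be a neural network trained with an autodiff library.

```python
import numpy as np

def features(t, x):
    """Hypothetical feature map; the toy model is u_theta(t, x) = features(t, x) @ theta."""
    return np.stack([x, t, x * t, np.ones_like(x)], axis=-1)

def train(data, steps=2000, batch=256, lr=0.05, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.zeros(4)
    losses = []
    for _ in range(steps):
        y = rng.choice(data, size=batch)          # mini-batch of data points y ~ p_data
        t = rng.uniform(0.0, 1.0, size=batch)     # step 1: t uniform on [0, 1]
        x0 = rng.standard_normal(batch)
        xt = t * y + (1.0 - t) * x0               # step 2: x_t ~ p(x_t | y)
        # Conditional target flow. With alpha_t = t and beta_t = 1 - t it simplifies
        # to y - x0, which avoids dividing by beta_t near t = 1.
        target = y - x0
        phi = features(t, xt)
        resid = phi @ theta - target
        losses.append(np.mean(resid ** 2))        # step 3: loss averaged over the batch
        grad = 2.0 * phi.T @ resid / batch        # gradient of the mean squared error
        theta -= lr * grad                        # step 4: gradient descent update
    return theta, losses

data = np.array([-2.0, 2.0])                      # hypothetical 1-D dataset
theta, losses = train(data)
```

The loss does not go to zero even for a perfect model, because the conditional target still varies with $y$ for a fixed $x_t$; only its posterior mean can be matched.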
Algorithm: Sampling from Flow Model
Given: Model $u_t^\theta$, number of steps $n$
1. Set $t = 0$
2. Set step-size $h = \frac{1}{n}$
3. Sample $x_0 \sim p_{\text{init}}$
4. for $i = 1, …, n$ do
a. $x_{t+h} = x_t + h \cdot u_t^\theta(x_t)$
b. $t = t + h$
5. end for
6. return $x_1$