Optimal Transport and Sinkhorn Iteration

Introduction

The optimal transport (OT) distance between two distributions measures the minimum amount of work (total cost) required to transform one distribution into the other. OT has been widely applied in various machine learning tasks, such as generative adversarial networks (GANs) [2], few-shot learning [3], self-supervised learning [4], information retrieval [5], object detection [6], and multi-label classification [7]. In this article, we introduce OT between two discrete distributions, the fast solution method proposed in [1], and the corresponding gradient computation.

For ease of understanding, this article only covers the optimal transport problem in the discrete case. It is worth noting that many resources on optimal transport use different terminologies, such as optimal transport / Wasserstein distance / earth mover's distance / Sinkhorn distance. In the discrete case, optimal transport / Wasserstein distance / earth mover's distance can be considered equivalent, and Sinkhorn iteration is a fast iterative solution method for optimal transport, which will be introduced later. The mathematical symbols used in this article follow those in reference [1], with detailed explanations provided.

The earth mover's distance

Suppose there are $m$ piles of soil on the ground, with the $i$-th pile containing $r_i$ units of soil. Meanwhile, there are $n$ pits, with the $j$-th pit capable of holding $c_j$ units of soil. We use

$$r = (r_1, \dots, r_m)^\top \in \mathbb{R}_+^m, \qquad c = (c_1, \dots, c_n)^\top \in \mathbb{R}_+^n$$

to represent the piles of soil and the pits, and assume $\sum_i r_i = \sum_j c_j$, meaning the total amount of soil exactly fills all the pits. Define $M \in \mathbb{R}_+^{m \times n}$ as the cost matrix, where $M_{ij}$ is the cost of moving one unit of soil from the $i$-th pile to the $j$-th pit. To transport the soil from $r$ to $c$, define $P \in \mathbb{R}_+^{m \times n}$ as the transportation plan, where $P_{ij}$ is the amount of soil moved from pile $i$ to pit $j$. The goal of optimal transport is to find an optimal transportation matrix that minimizes the total transportation cost.

Obviously, the rows of $P$ should sum to $r$ and its columns should sum to $c$, which is why the letters r(ow) and c(olumn) are used. Formally, $P$ satisfies the constraints:

$$P \mathbf{1}_n = r, \qquad P^\top \mathbf{1}_m = c, \qquad P_{ij} \ge 0.$$

The set of all transportation plans that satisfy these conditions is:

$$U(r, c) = \left\{ P \in \mathbb{R}_+^{m \times n} \;\middle|\; P \mathbf{1}_n = r,\; P^\top \mathbf{1}_m = c \right\},$$

where $\mathbf{1}_d$ denotes the $d$-dimensional vector of ones. It is worth mentioning that $U(r,c)$ is a bounded polytope in high-dimensional space, because all of its constraints are linear. In a plane, several linear constraints enclose a polygonal region. In high-dimensional space, the set enclosed by linear constraints is a polytope. Here, $U(r,c)$ contains all feasible transportation plans and is called the "feasible set" in optimization.

The optimal transport distance is defined as the minimum total transportation cost over all feasible plans:

$$d_M(r, c) = \min_{P \in U(r,c)} \langle P, M \rangle_F = \min_{P \in U(r,c)} \sum_{i,j} P_{ij} M_{ij},$$

where $\langle \cdot, \cdot \rangle_F$ is the Frobenius inner product. A plan $P^\star$ attaining this minimum is called an optimal transportation plan.

In short, optimal transport looks for the plan within the feasible region $U(r,c)$ that minimizes $\langle P, M \rangle_F$. This is a linear programming problem, since both the constraints defining $U(r,c)$ and the objective $\langle P, M \rangle_F$ are linear in $P$.
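Because the problem is a linear program, small instances can be solved exactly with an off-the-shelf LP solver. The sketch below assumes SciPy is available and uses scipy.optimize.linprog with the HiGHS backend. The helper name ot_linprog and the toy $r$, $c$, $M$ are made up for illustration.

```python
import numpy as np
from scipy.optimize import linprog

def ot_linprog(r, c, M):
    """Exact discrete OT: minimize <P, M>_F over U(r, c) as a linear program."""
    r, c, M = (np.asarray(a, dtype=float) for a in (r, c, M))
    m, n = M.shape
    # P is flattened row-major, so P[i, j] is the LP variable x[i * n + j].
    A_rows = np.kron(np.eye(m), np.ones((1, n)))  # row sums:    P 1_n = r
    A_cols = np.kron(np.ones((1, m)), np.eye(n))  # column sums: P^T 1_m = c
    res = linprog(
        M.ravel(),
        A_eq=np.vstack([A_rows, A_cols]),
        b_eq=np.concatenate([r, c]),
        bounds=(0, None),  # P_ij >= 0
        method="highs",
    )
    if not res.success:
        raise RuntimeError(res.message)
    return res.x.reshape(m, n), res.fun

r = np.array([0.2, 0.5, 0.3])
c = np.array([0.4, 0.4, 0.2])
idx = np.arange(3)
M = np.abs(idx[:, None] - idx[None, :]).astype(float)  # M_ij = |i - j|

P_star, cost = ot_linprog(r, c, M)
print(np.round(P_star, 3))
print("optimal cost:", round(cost, 4))
```

For the cost $M_{ij} = |i - j|$ on a line, the optimal cost equals the sum of absolute differences between the cumulative sums of $r$ and $c$. Here that is $|0.2 - 0.4| + |0.7 - 0.8| = 0.3$, which the solver should reproduce. The flattened LP has $mn$ variables, so this route gets expensive quickly as $m$ and $n$ grow.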

EMD as a measure of probability distribution distance

If $r$ and $c$ satisfy $\sum_i r_i = 1$ and $\sum_j c_j = 1$, then $r$ and $c$ can be regarded as two probability distributions. Any $P \in U(r,c)$ can then be viewed as a joint probability distribution with marginals $r$ and $c$, and the optimal transport distance between $r$ and $c$ measures how different the two distributions are. In many applications, the inputs are first normalized into probabilities (e.g., with softmax in classification), and the optimal transport distance between them is then computed. Optimal transport can therefore also serve as a distance between probability distributions.

Why optimal transport?

So why compute optimal transport at all? If the goal is to measure the distance between probability distributions, many metrics already exist, such as the simple KL divergence:

$$\mathrm{KL}(r \,\|\, c) = \sum_i r_i \log \frac{r_i}{c_i}.$$

KL divergence has known limitations: it becomes infinite when the support of $r$ is not contained in that of $c$ (in particular when the supports do not intersect), and it is asymmetric. Beyond these, an important drawback is that this pointwise metric ignores the structural information within the distribution, i.e., the relationships between its events. In KL divergence, the terms $r_i \log (r_i / c_i)$ are computed independently and then summed, although in most cases the events are not independent.

Take the common classification task as an example. Classification tasks usually use the cross-entropy loss to measure the distance between model predictions and sample labels. With one-hot labels, the cross-entropy loss equals the KL divergence between the labels and the model predictions, since a one-hot distribution has zero entropy. Such a pointwise loss function (whether cross-entropy or L2) cannot account for correlations between different events within the distribution. For instance, misclassifying a "car" as a "truck" is obviously less severe than misclassifying a "car" as a "zebra." Yet KL divergence assigns both errors the same loss as long as the predicted probability of "car" is the same, because with a one-hot label only the probability of the true class enters the loss.

Assume $r$ and $c$ are two discrete probability distributions satisfying $r^\top \mathbf{1}_m = 1$ and $c^\top \mathbf{1}_n = 1$, where $\mathbf{1}_m$ and $\mathbf{1}_n$ are all-ones column vectors. We can use optimal transport to compute the distance between them, and the structural information within the distributions can be "embedded" into the metric through the cost matrix $M$. In classification, for example, we can set $M_{\text{car},\text{truck}}$ to be much smaller than $M_{\text{car},\text{zebra}}$.
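The car/truck/zebra argument can be checked numerically. The sketch below uses a hypothetical 3-class cost matrix; the values 0.1 and 1.0 are made up for illustration. When the label $r$ is one-hot on class $k$, every plan in $U(r,c)$ has nonzero entries only in row $k$, with $P_{kj} = c_j$. The OT distance therefore reduces to $\sum_j M_{kj} c_j$, and no solver is needed.

```python
import numpy as np

# Hypothetical cost matrix over the classes (car, truck, zebra):
# confusing car and truck is cheap, confusing either with zebra is expensive.
M = np.array([[0.0, 0.1, 1.0],
              [0.1, 0.0, 1.0],
              [1.0, 1.0, 0.0]])

label = np.array([1.0, 0.0, 0.0])       # one-hot label: the true class is "car"
pred_truck = np.array([0.2, 0.7, 0.1])  # prediction that mistakes the car for a truck
pred_zebra = np.array([0.2, 0.1, 0.7])  # prediction that mistakes the car for a zebra

def cross_entropy(label, pred):
    """Pointwise loss: only the predicted probability of the true class matters."""
    return -np.sum(label * np.log(pred))

def ot_one_hot(label, pred, M):
    """OT distance when the label is one-hot on class k.

    The row constraint forces all mass into row k, and the column constraint
    then fixes P[k, j] = pred[j], so the cost is sum_j M[k, j] * pred[j].
    """
    k = int(np.argmax(label))
    return float(M[k] @ pred)

for name, pred in [("truck", pred_truck), ("zebra", pred_zebra)]:
    print(f"car -> {name}: CE = {cross_entropy(label, pred):.3f}, "
          f"OT = {ot_one_hot(label, pred, M):.3f}")
```

Both predictions receive the same cross-entropy ($-\log 0.2 \approx 1.609$). The OT distances are 0.17 for the truck mistake and 0.71 for the zebra mistake.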

Entropy Regularization in Optimal Transport

In reference [1], Cuturi proposed a fast method for solving the optimal transport problem. It first adds an entropy regularization term that smooths the feasible region of the original problem. It then turns the optimal transport problem into a matrix scaling problem, and finally solves that problem with the Sinkhorn iteration, obtaining an approximate solution to the original problem.

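The entropy of the joint distribution $P$ is:

$$h(P) = -\sum_{i,j} P_{ij} \log P_{ij}.$$

After adding entropy regularization to the original problem, the optimization objective becomes:

$$P^\lambda = \mathop{\arg\min}_{P \in U(r,c)} \; \langle P, M \rangle_F - \frac{1}{\lambda} h(P), \qquad d_M^\lambda(r, c) = \langle P^\lambda, M \rangle_F,$$

where $\lambda > 0$ controls the strength of the regularization. This problem is referred to as "entropy-regularized optimal transport," and its optimal solution is denoted $P^\lambda$. From the perspective of probability and information theory, more spread-out distributions have higher entropy, and the uniform distribution has the highest. Entropy regularization therefore pushes the transport plan toward a smoother, more spread-out matrix. Within $U(r,c)$, the maximum-entropy plan is the independent coupling $rc^\top$.

From the optimization perspective, a linear program over a bounded polytope always attains its optimum at a vertex, and a vertex of $U(r,c)$ has at most $m + n - 1$ nonzero entries. The optimal plan $P^\star$ of the original problem is therefore a sparse matrix with most elements equal to zero. With entropy regularization, the feasible region effectively shrinks inward to the smooth set

$$U_\alpha(r, c) = \left\{ P \in U(r, c) \;\middle|\; \mathrm{KL}(P \,\|\, rc^\top) \le \alpha \right\},$$

and the corresponding optimal solution is no longer sparse: every entry of $P^\lambda$ is strictly positive.

[Figure: The feasible regions $U(r,c)$ and $U_\alpha(r,c)$ contain all points that satisfy the constraints. $U(r,c)$ is a polytope in high-dimensional space, and after entropy regularization the feasible region becomes the smooth region $U_\alpha(r,c)$. The blue dashed lines are contour lines of the objective function, and the cost matrix $M$ determines the direction in which the objective decreases. Inspired by Fig. 1 in reference [1].]

As $\alpha \to \infty$, $U_\alpha(r,c)$ approaches $U(r,c)$ and the optimal solution approaches $P^\star$. As $\alpha \to 0$, $U_\alpha(r,c)$ shrinks to the single point $rc^\top$, which is then the optimal solution. In terms of $\lambda$, a large $\lambda$ corresponds to a large $\alpha$ and a small $\lambda$ to a small $\alpha$.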
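The sketch below makes the sparsity and entropy claims concrete on a small example. It builds one vertex of $U(r,c)$ with the northwest-corner rule, a classical greedy construction used here only for illustration. The vertex it returns is feasible but generally not optimal. The sketch then compares the vertex's entropy with that of the independent coupling $rc^\top$, which attains the maximum entropy $h(r) + h(c)$ in $U(r,c)$. The helper names are hypothetical.

```python
import numpy as np

def entropy(P):
    """h(P) = -sum p log p, using the convention 0 log 0 = 0."""
    p = P[P > 0]
    return -np.sum(p * np.log(p))

def northwest_corner(r, c):
    """Greedily build a vertex of U(r, c); it has at most m + n - 1 nonzeros."""
    r, c = np.array(r, dtype=float), np.array(c, dtype=float)  # work on copies
    P = np.zeros((len(r), len(c)))
    i = j = 0
    while i < len(r) and j < len(c):
        q = min(r[i], c[j])  # ship as much as possible from pile i to pit j
        P[i, j] = q
        r[i] -= q
        c[j] -= q
        if r[i] == 0:        # pile i is empty: move on to the next pile
            i += 1
        else:                # pit j is full: move on to the next pit
            j += 1
    return P

r = np.array([0.2, 0.5, 0.3])
c = np.array([0.4, 0.4, 0.2])
P_vertex = northwest_corner(r, c)
P_indep = np.outer(r, c)  # independent coupling r c^T

print("vertex:     nonzeros =", np.count_nonzero(P_vertex), " h =", entropy(P_vertex))
print("r c^T:      nonzeros =", np.count_nonzero(P_indep), " h =", entropy(P_indep))
print("h(r) + h(c) =", entropy(r) + entropy(c))
```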

Why entropy regularization

The author of [1] lists two reasons for using entropy regularization:

  1. The original linear program always has an optimal solution at a vertex of the feasible region $U(r,c)$, so the optimal plan $P^\star$ is typically sparse. A sparse solution leads to an unbalanced transportation plan, in which each pile ships its soil to only a few pits. Entropy regularization makes the transportation plan smoother and more balanced, as discussed in Section 3 of reference [1].

  2. More importantly, with entropy regularization the problem can be solved approximately with the Sinkhorn algorithm, which greatly reduces the computational cost. This is detailed in Section 4 of reference [1].

In summary, the solution to the entropy-regularized optimal transport problem is smoother and more balanced, and the Sinkhorn iterative algorithm can be used for fast computation.

Solving Optimal Transport with Sinkhorn Iteration

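Let $u \in \mathbb{R}^m$ and $v \in \mathbb{R}^n$ be two scaling vectors, initialized as

$$u^{(0)} = \mathbf{1}_m, \qquad v^{(0)} = \mathbf{1}_n$$

(any positive initialization works). Then compute $P^\lambda$ with the Sinkhorn iteration:

$$u^{(k+1)} = r \oslash \left( K v^{(k)} \right), \qquad v^{(k+1)} = c \oslash \left( K^\top u^{(k+1)} \right),$$

where $\oslash$ denotes elementwise division and $K = e^{-\lambda M}$ is the elementwise exponential of $-\lambda M$ (see Lemma 2 in [1]). After the iteration converges, the optimal transport matrix and the corresponding transport cost are given by:

$$P^\lambda = \operatorname{diag}(u)\, K \operatorname{diag}(v), \qquad d_M^\lambda(r, c) = \langle P^\lambda, M \rangle_F = u^\top \left( (K \odot M)\, v \right),$$

where $\odot$ is the elementwise product.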
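The iteration above can be sketched in a few lines of NumPy. This sketch assumes dense inputs and a moderate $\lambda$, so that $e^{-\lambda M}$ does not underflow. The name sinkhorn_np, the fixed iteration count and the toy inputs are choices made here, not taken from [1] or from vlkit.

```python
import numpy as np

def sinkhorn_np(r, c, M, lam, n_iters=1000):
    """Entropy-regularized OT by Sinkhorn scaling: returns (P^lambda, <P^lambda, M>_F)."""
    K = np.exp(-lam * M)                     # K = e^{-lambda M}, elementwise
    u, v = np.ones_like(r), np.ones_like(c)  # any positive initialization works
    for _ in range(n_iters):
        u = r / (K @ v)                      # rescale rows so that P 1_n = r
        v = c / (K.T @ u)                    # rescale columns so that P^T 1_m = c
    P = u[:, None] * K * v[None, :]          # P = diag(u) K diag(v)
    cost = u @ ((K * M) @ v)                 # u^T ((K * M) v) equals <P, M>_F
    return P, cost

r = np.array([0.2, 0.5, 0.3])
c = np.array([0.4, 0.4, 0.2])
idx = np.arange(3)
M = np.abs(idx[:, None] - idx[None, :]) / 2.0  # M_ij = |i - j|, scaled to [0, 1]

for lam in (0.01, 1.0, 10.0):
    P, cost = sinkhorn_np(r, c, M, lam)
    print(f"lambda = {lam:>5}: cost = {cost:.4f}")
```

For these inputs, the exact (unregularized) OT cost is 0.15. The printed costs should therefore decrease toward 0.15 from above as $\lambda$ grows, while a tiny $\lambda$ gives a plan close to $rc^\top$. A very large $\lambda$ makes $e^{-\lambda M}$ underflow, so this sketch is only meant for moderate values.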

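The gradients of the transport cost with respect to $r$ and $c$ are given by the dual potentials of the regularized problem. These are computed from the converged scaling vectors $u$ and $v$:

$$\alpha = \frac{1}{\lambda} \log u - \frac{\mathbf{1}_m^\top \log u}{\lambda m} \mathbf{1}_m, \qquad \beta = \frac{1}{\lambda} \log v - \frac{\mathbf{1}_n^\top \log v}{\lambda n} \mathbf{1}_n,$$

with $\log$ applied elementwise. The second term centers each potential so that it sums to zero. This removes an additive constant that is undetermined because the total mass of $r$ and $c$ is fixed. Strictly speaking, $\alpha$ and $\beta$ are the gradients of the regularized objective $\langle P, M \rangle_F - \frac{1}{\lambda} h(P)$ at its optimum, and [7] uses them as the gradients of the optimal transport loss.

With the transport cost and these gradients, the optimal transport distance can be used as a loss function to train neural networks.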
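As a sanity check on the gradient formula, the sketch below compares $\alpha^\top \delta$ with a central finite difference of the regularized objective $\langle P, M \rangle_F - \frac{1}{\lambda} h(P)$ along a random zero-sum direction $\delta$. The random problem, $\lambda = 5$, the step size and the function names are arbitrary choices for this illustration. It checks the regularized objective only, not $\langle P^\lambda, M \rangle_F$ on its own.

```python
import numpy as np

def sinkhorn_potentials(r, c, M, lam, n_iters=5000):
    """Sinkhorn plan P^lambda plus centered potentials alpha (for r) and beta (for c)."""
    K = np.exp(-lam * M)
    u, v = np.ones_like(r), np.ones_like(c)
    for _ in range(n_iters):
        u = r / (K @ v)
        v = c / (K.T @ u)
    P = u[:, None] * K * v[None, :]
    alpha = (np.log(u) - np.log(u).mean()) / lam  # (1/lambda) log u, centered
    beta = (np.log(v) - np.log(v).mean()) / lam   # (1/lambda) log v, centered
    return P, alpha, beta

def regularized_objective(r, c, M, lam):
    """<P, M>_F - h(P) / lambda, evaluated at the Sinkhorn solution."""
    P, _, _ = sinkhorn_potentials(r, c, M, lam)
    return np.sum(P * M) + np.sum(P * np.log(P)) / lam

rng = np.random.default_rng(0)
m, n, lam = 4, 5, 5.0
r = rng.random(m) + 0.1
r /= r.sum()
c = rng.random(n) + 0.1
c /= c.sum()
M = rng.random((m, n))

_, alpha, beta = sinkhorn_potentials(r, c, M, lam)

# Perturb r along a zero-sum direction so that its total mass still matches c.
delta = rng.standard_normal(m)
delta -= delta.mean()
eps = 1e-5
fd = (regularized_objective(r + eps * delta, c, M, lam)
      - regularized_objective(r - eps * delta, c, M, lam)) / (2 * eps)
print("finite difference:", fd)
print("alpha . delta:    ", alpha @ delta)
```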

PyTorch Implementation

The function vlkit.optimal_transport.sinkhorn is a PyTorch implementation of the Sinkhorn algorithm, which we use below to compute and visualize the optimal transport between two distributions.

For example, generate two 1D Gaussian distributions as the source and target distributions:

import torch
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
from vlkit.optimal_transport import sinkhorn

# generate two gaussians as the source and target
def gaussian(mean=0, std=10, n=100):
    d = (-(torch.arange(n) - mean)**2 / (2 * std**2)).exp()
    d /= d.sum()
    return d

n = 20
d1 = gaussian(mean=12, std=2, n=n)
d2 = gaussian(mean=6, std=4, n=n)

# cost matrix: dist[i, j] = |i - j|, normalized to [0, 1]
dist = (torch.arange(n).view(1, n) - torch.arange(n).view(n, 1)).abs().float()
dist /= dist.max()

# visualize the source and target distributions
fig, axes = plt.subplots(1, 2, figsize=(9, 3))
axes[0].bar(torch.arange(n), d1)
axes[0].set_title('Source distribution')
axes[1].bar(torch.arange(n), d2)
axes[1].set_title('Target distribution')
plt.tight_layout()

Then compute the optimal transport plan with the Sinkhorn iteration and visualize the result.

T, u, v = sinkhorn(
    r=d1.unsqueeze(dim=0),
    c=d2.unsqueeze(dim=0),
    reg=1e-2,
    M=dist.unsqueeze(dim=0)
)
plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(3, 3)

ax1 = plt.subplot(gs[0, 1:3])
plt.bar(torch.arange(n), d2, label='Target distribution')

ax2 = plt.subplot(gs[1:, 0])
ax2.barh(torch.arange(n), d1, label='Source distribution')

plt.gca().invert_xaxis()
plt.gca().invert_yaxis()

plt.subplot(gs[1:3, 1:3], sharex=ax1, sharey=ax2)
plt.imshow(T.squeeze(dim=0))
plt.axis('off')

plt.tight_layout()

Additional material for further reading

Here are some additional materials on optimal transport for further reading:

@misc{zhao2020optimal,
    title   = {Optimal Transport and Sinkhorn Iteration},
    author  = {Kai Zhao},
    year    = 2020,
    note    = {\url{http://kaizhao.net/blog/ot}}
}

References

  1. Cuturi, Marco. "Sinkhorn distances: Lightspeed computation of optimal transport." Advances in Neural Information Processing Systems 26 (2013): 2292-2300.

  2. Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International conference on machine learning. PMLR, 2017.

  3. Zhang, Chi, et al. "DeepEMD: Few-Shot Image Classification With Differentiable Earth Mover's Distance and Structured Classifiers." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.

  4. Liu, Songtao, Zeming Li, and Jian Sun. "Self-EMD: Self-Supervised Object Detection without ImageNet." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.

  5. Xie, Yujia, et al. "Differentiable top-k operator with optimal transport." arXiv preprint arXiv:2002.06504 (2020).

  6. Ge, Zheng, et al. "OTA: Optimal Transport Assignment for Object Detection." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.

  7. Frogner, Charlie, et al. "Learning with a Wasserstein loss." Advances in Neural Information Processing Systems (2015).