Optimal Transport and Sinkhorn Iteration
Introduction
The Optimal Transport (OT) distance between two distributions measures the minimum amount of work required to transform one distribution into the other. OT has been widely applied in various machine learning tasks, such as generative adversarial networks (GANs) [2], few-shot learning [3], self-supervised learning [4], information retrieval [5], object detection [6], and multi-label classification [7]. In this article, we introduce OT between two discrete distributions, its solution method proposed in [1], and the corresponding gradient computation.
For ease of understanding, this article only covers the optimal transport problem in the discrete case. It is worth noting that many resources on optimal transport use different terminologies, such as optimal transport / Wasserstein distance / earth mover's distance / Sinkhorn distance. In the discrete case, optimal transport / Wasserstein distance / earth mover's distance can be considered equivalent, and Sinkhorn iteration is a fast iterative solution method for optimal transport, which will be introduced later. The mathematical symbols used in this article follow those in reference [1], with detailed explanations provided.
The earth mover's distance
Suppose there are $n$ piles of soil and $m$ pits. We use two non-negative vectors $r \in \mathbb{R}_{+}^{n}$ and $c \in \mathbb{R}_{+}^{m}$ to represent the piles of soil and pits, and assume the total amount of soil equals the total capacity of the pits: $\sum_i r_i = \sum_j c_j$. A transportation plan is a matrix $P \in \mathbb{R}_{+}^{n \times m}$, where $P_{ij}$ is the amount of soil moved from pile $i$ to pit $j$; obviously, the sum of each row/column of the transportation matrix $P$ must match the corresponding entry of $r$/$c$:

$$P \mathbf{1}_m = r, \qquad P^{\top} \mathbf{1}_n = c.$$

The set of all transportation plans forms the transport polytope

$$U(r, c) = \{\, P \in \mathbb{R}_{+}^{n \times m} \mid P \mathbf{1}_m = r,\ P^{\top} \mathbf{1}_n = c \,\},$$

where $\mathbf{1}_m$ denotes the $m$-dimensional all-ones vector.

The optimal transport (earth mover's) distance between $r$ and $c$ is

$$d_M(r, c) = \min_{P \in U(r, c)} \langle P, M \rangle = \min_{P \in U(r, c)} \sum_{i, j} P_{ij} M_{ij},$$

where $M \in \mathbb{R}_{+}^{n \times m}$ is the cost matrix and $M_{ij}$ is the cost of moving one unit of soil from pile $i$ to pit $j$.

In short, optimal transport aims to find the transportation plan $P^{\star} \in U(r, c)$ that transforms distribution $r$ into distribution $c$ with the minimum total cost $\langle P^{\star}, M \rangle$.
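To make the definition concrete, here is a minimal sketch (my own example; the 3-bin marginals and the $|i - j|$ cost matrix are made up for illustration) that solves this linear program directly with scipy.optimize.linprog:

import numpy as np
from scipy.optimize import linprog

n, m = 3, 3
r = np.array([0.5, 0.3, 0.2])   # source marginal
c = np.array([0.2, 0.3, 0.5])   # target marginal
M = np.abs(np.arange(n)[:, None] - np.arange(m)[None, :]).astype(float)  # |i - j| cost

# Equality constraints on the flattened plan P.ravel(): row sums equal r, column sums equal c.
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1.0   # row-sum constraints
for j in range(m):
    A_eq[n + j, j::m] = 1.0            # column-sum constraints
b_eq = np.concatenate([r, c])

res = linprog(M.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
P = res.x.reshape(n, m)                # optimal transportation plan
print('EMD =', res.fun)

For larger problems this LP quickly becomes expensive, which is exactly the motivation for the Sinkhorn iteration introduced below.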
EMD as a measure of probability distribution distance
If $r$ and $c$ are normalized to sum to one, they can be viewed as two discrete probability distributions, and $d_M(r, c)$ serves as a measure of the distance between them; in the discrete case this is exactly the Wasserstein distance, also known as the earth mover's distance (EMD).
Why optimal transport?
So why do we compute optimal transport? If the goal is to measure the distance between probability distributions, there are many existing metrics available, such as the very simple KL divergence:

$$D_{\mathrm{KL}}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i}.$$

Besides reasons such as "unable to handle the case where the support sets of the two distributions do not intersect" and "asymmetric", an important reason is that this pointwise metric does not consider the structural information within the distribution. Structural information refers to the relationships between different events within a distribution. For example, in KL divergence, each term $p_i \log \frac{p_i}{q_i}$ only compares the probabilities $p_i$ and $q_i$ of the same event $i$ and ignores any relationship between different events $i$ and $j$; optimal transport, in contrast, encodes these relationships through the cost matrix $M$.
Take the common classification task as an example. Classification tasks usually use cross-entropy loss to measure the distance between model predictions and sample labels. Cross-entropy loss essentially computes the KL divergence between the one-hot encoded labels and the model predictions. Such a pointwise loss function (whether cross-entropy or L2) cannot account for the correlation between different events within the distribution. For instance, misclassifying a "car" as a "truck" is obviously less severe than misclassifying a "car" as a "zebra." However, measured by KL divergence, the losses for these two errors are identical.
Assume the true label is "car": because the one-hot target puts all its mass on a single class, cross-entropy only looks at the predicted probability of "car" and assigns the same loss whether the remaining probability mass goes to "truck" or to "zebra." An OT-based loss with a cost matrix that encodes class similarity (small cost between "car" and "truck", large cost between "car" and "zebra") penalizes the second mistake more heavily.
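As a toy illustration (my own example; the class order, predictions, and cost values are made up), the following sketch shows that cross-entropy assigns the same loss to both mistakes, while a transport cost that encodes class similarity does not:

import torch

classes = ['car', 'truck', 'zebra']
label = torch.tensor([1.0, 0.0, 0.0])        # one-hot target: "car"
pred_truck = torch.tensor([0.4, 0.6, 0.0])   # mass mistakenly put on "truck"
pred_zebra = torch.tensor([0.4, 0.0, 0.6])   # mass mistakenly put on "zebra"

def cross_entropy(target, pred, eps=1e-12):
    return -(target * (pred + eps).log()).sum()

print(cross_entropy(label, pred_truck))  # identical loss
print(cross_entropy(label, pred_zebra))  # identical loss

# A made-up cost matrix: moving mass car<->truck is cheap, car<->zebra is expensive.
M = torch.tensor([[0.0, 0.1, 1.0],
                  [0.1, 0.0, 1.0],
                  [1.0, 1.0, 0.0]])
# Since the target puts all mass on "car", every plan must move all predicted mass there,
# so the optimal transport cost is just the cost of relocating the misplaced 0.6:
print((0.6 * M[1, 0]).item())  # "truck" mistake: 0.06
print((0.6 * M[2, 0]).item())  # "zebra" mistake: 0.60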
Entropy Regularization in Optimal Transport
In reference [1], Cuturi proposed a fast solution method for the optimal transport problem. This method first introduces entropy regularization to smooth the feasible region of the original problem, then transforms the regularized optimal transport problem into a matrix scaling problem, and finally uses the Sinkhorn iteration to obtain an approximate solution to the original problem.
The entropy of the joint distribution (the transportation plan) $P$ is defined as

$$h(P) = -\sum_{i, j} P_{ij} \log P_{ij}.$$

After adding entropy regularization to the original problem, the objective becomes

$$d_M^{\lambda}(r, c) = \langle P^{\lambda}, M \rangle, \qquad P^{\lambda} = \operatorname*{arg\,min}_{P \in U(r, c)} \langle P, M \rangle - \frac{1}{\lambda} h(P),$$

where $\lambda > 0$ controls the strength of the regularization. The regularized problem is strictly convex and therefore has a unique solution $P^{\lambda}$.

From the optimization perspective, the solution to the original linear program lies at a vertex of the feasible region $U(r, c)$, while the entropy term pulls the solution toward the interior of the polytope, yielding a smoother transportation plan. As $\lambda \to \infty$ the entropy term vanishes and $P^{\lambda}$ approaches the solution of the original problem; as $\lambda \to 0$ the entropy term dominates and $P^{\lambda}$ approaches the maximally smooth independent coupling $r c^{\top}$.
Why entropy regularization
The author of [1] lists two reasons for using entropy regularization:
- The solution to the original linear programming problem must lie at a vertex of the feasible region $U(r, c)$, resulting in a sparse matrix $P^{\star}$. A sparse solution leads to an imbalanced transportation plan. Entropy regularization makes the transportation matrix smoother and more balanced, as discussed in Section 3 of reference [1].
- More importantly, after adding entropy regularization, the problem can be solved approximately with the Sinkhorn algorithm, significantly reducing computational cost. This is detailed in Section 4 of the original paper.
In summary, the solution to the entropy-regularized optimal transport problem is smoother and more balanced, and the Sinkhorn iterative algorithm can be used for fast computation.
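As a quick toy illustration (my own example, not from [1]): with identical marginals, a vertex-like sparse plan has much lower entropy than the smooth independent coupling $r c^{\top}$, so the entropy term favors the latter:

import torch

def entropy(P, eps=1e-12):
    # h(P) = -sum_ij P_ij log P_ij
    return -(P * (P + eps).log()).sum()

r = torch.tensor([0.5, 0.5])
c = torch.tensor([0.5, 0.5])

P_sparse = torch.tensor([[0.5, 0.0],
                         [0.0, 0.5]])   # a vertex of U(r, c)
P_smooth = torch.outer(r, c)            # independent coupling r c^T

print(entropy(P_sparse))  # ~0.693
print(entropy(P_smooth))  # ~1.386: higher entropy, favored by the regularizer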
Solving Optimal Transport with Sinkhorn Iteration
Assume vectors $u \in \mathbb{R}_{+}^{n}$ and $v \in \mathbb{R}_{+}^{m}$. Cuturi shows that the solution of the entropy-regularized problem can be written in the scaling form

$$P^{\lambda} = \operatorname{diag}(u)\, K \operatorname{diag}(v), \qquad K = e^{-\lambda M} \ (\text{elementwise}).$$

Then solve for $u$ and $v$ using the Sinkhorn iteration algorithm:

$$u \leftarrow r \oslash (K v), \qquad v \leftarrow c \oslash (K^{\top} u),$$

where $\oslash$ denotes elementwise division; the two updates are repeated alternately until $u$ and $v$ converge.
The gradients of the regularized objective $\min_{P \in U(r, c)} \langle P, M \rangle - \frac{1}{\lambda} h(P)$ with respect to $r$ and $c$ can be read off from the scaling vectors: up to an additive constant they equal $\frac{1}{\lambda} \log u$ and $\frac{1}{\lambda} \log v$, respectively; in practice one can also simply backpropagate through the Sinkhorn iterations with automatic differentiation.
With the scaling form, the iteration, and the gradients above, we can compute both the (approximate) optimal transport plan and its gradients, so the Sinkhorn distance can be used as a differentiable loss and trained end to end.
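Here is a minimal from-scratch sketch of this iteration in PyTorch (my own simplified version for illustration, not the vlkit implementation used below; the reg argument plays the role of $1 / \lambda$):

import torch

def sinkhorn_simple(r, c, M, reg=1e-2, n_iters=1000, eps=1e-8):
    # r: (n,) source marginal, c: (m,) target marginal, M: (n, m) cost matrix
    K = torch.exp(-M / reg)                      # Gibbs kernel e^{-M / reg}
    u, v = torch.ones_like(r), torch.ones_like(c)
    for _ in range(n_iters):
        u = r / (K @ v + eps)                    # scale rows to match r
        v = c / (K.t() @ u + eps)                # scale columns to match c
    return u.unsqueeze(1) * K * v.unsqueeze(0)   # diag(u) K diag(v)

With enough iterations this should produce a plan close to the one returned by the library call below, and since the result is an ordinary tensor it can also be backpropagated through directly.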
PyTorch Implementation
Here (vlkit.optimal_transport.sinkhorn) is a PyTorch implementation of the Sinkhorn algorithm, which we can use to compute and visualize the optimal transport between two distributions.
For example, generate two 1D Gaussian distributions as the source and target distributions:
import torch
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
from vlkit.optimal_transport import sinkhorn
# generate two gaussians as the source and target
def gaussian(mean=0, std=10, n=100):
    # discretized Gaussian over n bins, normalized to sum to 1
    d = (-(torch.arange(n) - mean)**2 / (2 * std**2)).exp()
    d /= d.sum()
    return d
n = 20
d1 = gaussian(mean=12, std=2, n=n)
d2 = gaussian(mean=6, std=4, n=n)
dist = (torch.arange(n).view(1, n) - torch.arange(n).view(n, 1)).abs().float()
dist /= dist.max()
# visualize the source and target distributions
fig, axes = plt.subplots(1, 2, figsize=(9, 3))
axes[0].bar(torch.arange(n), d1)
axes[0].set_title('Source distribution')
axes[1].bar(torch.arange(n), d2)
axes[1].set_title('Target distribution')
plt.tight_layout()
Then compute the optimal transport using the sinkhorn
iteration and visualize the results.
T, u, v = sinkhorn(
r=d1.unsqueeze(dim=0),
c=d2.unsqueeze(dim=0),
reg=1e-2,
M=dist.unsqueeze(dim=0)
)
plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(3, 3)
ax1 = plt.subplot(gs[0, 1:3])
plt.bar(torch.arange(n), d2, label='Target distribution')
ax2 = plt.subplot(gs[1:, 0])
ax2.barh(torch.arange(n), d1, label='Source distribution')
plt.gca().invert_xaxis()
plt.gca().invert_yaxis()
plt.subplot(gs[1:3, 1:3], sharex=ax1, sharey=ax2)
plt.imshow(T.squeeze(dim=0))
plt.axis('off')
plt.tight_layout()
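The reg argument controls the strength of the entropy regularization. As a quick experiment (my own addition, assuming the same sinkhorn signature as in the snippet above), re-running the call with different values shows that a larger reg yields a smoother, more diffuse plan, while a smaller reg gives a sparser one:

fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, reg in zip(axes, [1e-2, 5e-2, 1e-1]):
    T_reg, _, _ = sinkhorn(
        r=d1.unsqueeze(dim=0),
        c=d2.unsqueeze(dim=0),
        reg=reg,
        M=dist.unsqueeze(dim=0)
    )
    ax.imshow(T_reg.squeeze(dim=0))   # transport plan for this regularization strength
    ax.set_title(f'reg={reg}')
    ax.axis('off')
plt.tight_layout()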
Additional material for further reading
Here are some additional materials on optimal transport for further reading:
- Machine Learning Tools (II): Notes of Optimal Transport
- Introduction to Optimal Transport
- Lecture 11.3: Discrete Optimal Transport (cont.) | Sinkhorn Iterations | CVF20 (YouTube video)
- Notes on Optimal Transport
- Python Optimal Transport Library
- PyTorch Wasserstein
@misc{zhao2020optimal,
title = {Optimal Transport and Sinkhorn Iteration},
author = {Kai Zhao},
year = 2020,
note = {\url{http://kaizhao.net/blog/ot}}
}
References
1. Cuturi, Marco. "Sinkhorn Distances: Lightspeed Computation of Optimal Transport." Advances in Neural Information Processing Systems 26 (2013): 2292-2300.
2. Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein Generative Adversarial Networks." International Conference on Machine Learning. PMLR, 2017.
3. Zhang, Chi, et al. "DeepEMD: Few-Shot Image Classification With Differentiable Earth Mover's Distance and Structured Classifiers." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.
4. Liu, Songtao, Zeming Li, and Jian Sun. "Self-EMD: Self-Supervised Object Detection without ImageNet." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
5. Xie, Yujia, et al. "Differentiable Top-k Operator with Optimal Transport." arXiv preprint arXiv:2002.06504 (2020).
6. Ge, Zheng, et al. "OTA: Optimal Transport Assignment for Object Detection." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
7. Frogner, Charlie, et al. "Learning with a Wasserstein Loss." Advances in Neural Information Processing Systems 28 (2015).