ngushchin / entropicotbenchmark
Entropic Optimal Transport Benchmark (NeurIPS 2023).
Home Page: https://arxiv.org/abs/2306.10161
License: MIT License
Hi,
Thank you very much for this great effort, from which the community will really profit!
While trying to set up this benchmark, I ran into a few issues that make it hard to use:
- Which Python version should be used with the `requirements.txt` file? I tried a few, which resulted in some packages not being available (e.g. the pinned numpy version). In the end, I figured out that python==3.9 works. Is this the version you also use?
- [...] `torch` backend.
- `from eot_benchmark import ...` only works for me when I replace it with `from benchmark import ...` in cell 1 of https://github.com/ngushchin/EntropicOTBenchmark/blob/main/notebooks/mixtures_benchmark_visualization_eot.ipynb. Similar issues arise with relative imports and with paths appended via `sys.path.append`.
- I get `True Wasserstein-2 Distance: 2.920156971631272`, `Variance of X: 4.264332`, and `Variance of Y: 4.24657`, which differ significantly from the values provided in the notebook.

I tried to replicate the notebook above using the following steps:
import sys
sys.path.append(".../EntropicOTBenchmark/benchmark")
sys.path.append(".../EntropicOTBenchmark")
sys.path.append(".../EntropicOTBenchmark/baselines/EntropicNeuralOptimalTransport")
import matplotlib
import math
import gc
#import wandb
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy import linalg
import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F
torch.random.manual_seed(0xBADBEEF)
from scipy.stats import ortho_group
from scipy.linalg import sqrtm, inv
# from src.icnn import DenseICNN
# from src.tools import compute_l1_norm, ewma
#from src import distributions
#from src.tools import unfreeze, freeze
from sklearn.metrics import pairwise_distances
from sklearn.decomposition import PCA
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on an
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on an
               representative data set.
    Returns:
    --   : The Frechet Distance.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'
    diff = mu1 - mu2
    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real
    tr_covmean = np.trace(covmean)
    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)
DIM = 16 # 2,16,64,128
EPS = 0.1 # 0.1, 1, 10
SEED = 987
GPU_DEVICE = 0
BATCH_SIZE=1024
NUM_SAMPLES_PLOT=1000
SELECTED_IDX = [233,43,12,62,555]
SELECTED_ITERS = 16
NUM_SAMPLES_METRICS=1000
torch.manual_seed(0), np.random.seed(0)
from src import distributions
OUTPUT_SEED = 0xC0FFEE
np.random.seed(OUTPUT_SEED)
torch.manual_seed(OUTPUT_SEED)
mu_0 = np.zeros(DIM)
mu_T = np.zeros(DIM)
mu_optimal_plan = np.zeros(2*DIM)
rotation_Y = ortho_group.rvs(DIM)
weight_Y = rotation_Y @ np.diag(np.exp(np.linspace(np.log(0.5), np.log(2), DIM)))
sigma_Y = weight_Y @ weight_Y.T
Y_sampler = distributions.LinearTransformer(distributions.StandartNormalSampler(dim=DIM), weight_Y, bias=None)
rotation_X = ortho_group.rvs(DIM)
weight_X = rotation_X @ np.diag(np.exp(np.linspace(np.log(0.5), np.log(2), DIM)))
sigma_X = weight_X @ weight_X.T
X_sampler = distributions.LinearTransformer(distributions.StandartNormalSampler(dim=DIM), weight_X, bias=None)
BW = calculate_frechet_distance(np.zeros(DIM), sigma_X, np.zeros(DIM), sigma_Y) / 2
print('True Wasserstein-2 Distance: ', BW)
X = X_sampler.sample(100000).cpu().detach().numpy()
Var_X = np.sum(np.var(X, axis=0))
print('Variance of X:', Var_X)
Y = Y_sampler.sample(100000).cpu().detach().numpy()
Var_Y = np.sum(np.var(Y, axis=0))
print('Variance of Y:', np.sum(Var_Y))
torch.cuda.empty_cache()
Thank you so much for this!
Best,
Dominik
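One way to sanity-check the printed variances independently of the repository code: in the script above, each weight matrix is a rotation times a diagonal of scales `exp(linspace(log 0.5, log 2, DIM))`, so the trace of the covariance (the summed per-coordinate variance) depends only on `DIM`, not on the random rotation. The sketch below is a minimal illustration under that assumption. It uses plain NumPy, draws the rotation with a QR decomposition instead of `ortho_group`, and the helper names `expected_total_variance` and `random_rotation` are hypothetical, not part of the benchmark.

```python
import numpy as np


def expected_total_variance(dim):
    """Trace of weight @ weight.T for weight = rotation @ diag(scales).

    The rotation cancels inside the trace, so only the squared scales remain.
    """
    scales = np.exp(np.linspace(np.log(0.5), np.log(2.0), dim))
    return float(np.sum(scales ** 2))


def random_rotation(dim, rng):
    """Random orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))
    return q


rng = np.random.default_rng(0)
dim = 2  # try 2, 16, 64, 128 as in the script above

scales = np.exp(np.linspace(np.log(0.5), np.log(2.0), dim))
weight = random_rotation(dim, rng) @ np.diag(scales)
sigma = weight @ weight.T

# Samples x = weight @ z with z ~ N(0, I), stored row-wise
samples = rng.standard_normal((100_000, dim)) @ weight.T
empirical = float(np.sum(np.var(samples, axis=0)))

print("analytic trace :", expected_total_variance(dim))
print("trace of sigma :", float(np.trace(sigma)))
print("empirical value:", empirical)
```

If the empirical value from the benchmark samplers is far from the analytic trace for the chosen `DIM`, the mismatch is more likely in the configuration (for instance the dimension or seed actually used) than in sampling noise.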
In EntropicOTBenchmark/baselines/EntropicNeuralOptimalTransport/notebooks/High_dimensionsal_gaussians.ipynb, I believe that the closed-form entropic optimal transport (EOT) plan between Gaussian measures is not correct.
It does not coincide with the one proved in [Janati+2020] and implemented in https://github.com/hichamjanati/Entropic-OT-gaussians/blob/master/closed_forms.py (see the function `closed_form_balanced`).
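Indeed, [Janati+2020] states that for $\alpha = \mathcal{N}(a, A)$ and $\beta = \mathcal{N}(b, B)$, the optimal plan is the Gaussian
$$\pi^\star = \mathcal{N}\left(\begin{pmatrix} a \\ b \end{pmatrix}, \begin{pmatrix} A & C_\sigma \\ C_\sigma^\top & B \end{pmatrix}\right),$$
where
$$C_\sigma = \tfrac{1}{2} A^{1/2} D_\sigma A^{-1/2} - \tfrac{\sigma^2}{2} I, \qquad D_\sigma = \left(4 A^{1/2} B A^{1/2} + \sigma^4 I\right)^{1/2}.$$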
In cell [26] of Sec. 4 of the High_dimensionsal_gaussians.ipynb notebook, there seems to be an error in the variable change:
- `get_C_sigma` returns `0.5*(covariance_0_sqrt@D_sigma@covariance_0_sqrt_inv - epsilon*np.eye(shape))`.
- `get_D_sigma` returns `symmetrize(sqrtm(4*covariance_0_sqrt@covariance_T@covariance_0_sqrt + (epsilon**2)*np.eye(shape)))`.

However, if we rewrite $\epsilon = 2\sigma^2$, with $\sigma$ as in [Janati+2020], then $\sigma^2 = \epsilon/2$ and $\sigma^4 = \epsilon^2/4$.
Therefore, `epsilon**2` in `get_D_sigma` should be replaced by `epsilon**2 / 4`, and `epsilon` in `get_C_sigma` should be replaced by `epsilon / 2`, i.e.
- `0.5*(covariance_0_sqrt@D_sigma@covariance_0_sqrt_inv - (epsilon / 2)*np.eye(shape))`.
- `symmetrize(sqrtm(4*covariance_0_sqrt@covariance_T@covariance_0_sqrt + (epsilon**2 / 4)*np.eye(shape)))`.
This is a crucial problem because, the way the functions are currently coded, the resulting covariance matrix has significantly negative eigenvalues.
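The eigenvalue claim can be checked numerically. The sketch below is a minimal, hedged illustration and not the notebook's code: it assumes the closed form from [Janati+2020] in the shape $C_\sigma = \tfrac{1}{2} A^{1/2} D_\sigma A^{-1/2} - \tfrac{\sigma^2}{2} I$ with $D_\sigma = (4 A^{1/2} B A^{1/2} + \sigma^4 I)^{1/2}$, takes $\sigma^2$ as a free parameter named `sigma2`, and reports the smallest eigenvalue of the joint covariance. The function names `spd_sqrt`, `joint_covariance` and `min_eigenvalue` are hypothetical. Which value of `sigma2` corresponds to a given `epsilon` depends on the cost and regularization convention, which is the point under discussion here.

```python
import numpy as np


def spd_sqrt(m):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh((m + m.T) / 2)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def joint_covariance(cov_0, cov_T, sigma2):
    """Joint covariance [[A, C], [C^T, B]] under the assumed closed form.

    sigma2 plays the role of sigma^2; its relation to epsilon is left to the caller.
    """
    dim = cov_0.shape[0]
    a_sqrt = spd_sqrt(cov_0)
    a_sqrt_inv = np.linalg.inv(a_sqrt)
    d_sigma = spd_sqrt(4 * a_sqrt @ cov_T @ a_sqrt + sigma2 ** 2 * np.eye(dim))
    c_sigma = 0.5 * a_sqrt @ d_sigma @ a_sqrt_inv - 0.5 * sigma2 * np.eye(dim)
    return np.block([[cov_0, c_sigma], [c_sigma.T, cov_T]])


def min_eigenvalue(cov):
    """Smallest eigenvalue of the symmetrized matrix."""
    return float(np.linalg.eigvalsh((cov + cov.T) / 2).min())


rng = np.random.default_rng(0)
dim = 4
m0 = rng.standard_normal((dim, dim))
m1 = rng.standard_normal((dim, dim))
cov_0 = m0 @ m0.T + np.eye(dim)  # arbitrary SPD test matrices
cov_T = m1 @ m1.T + np.eye(dim)

for sigma2 in (0.05, 0.1, 1.0, 10.0):
    joint = joint_covariance(cov_0, cov_T, sigma2)
    print(f"sigma2={sigma2}: min eigenvalue = {min_eigenvalue(joint):.3e}")
```

Running the same check on the matrices produced by the notebook's `get_C_sigma` and `get_D_sigma` would show directly whether the joint covariance has negative eigenvalues beyond numerical noise.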
The README suggests running `pip install -r requirements.txt`, but the file does not seem to have been included in this repository.
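Until a pinned file is available, one option is to record the versions from an environment in which the notebooks are known to run. The sketch below is only an illustration: it uses the standard-library `importlib.metadata`, and the package list is an assumption taken from the imports in the replication script earlier in this thread, not the project's actual requirements. The helper names `installed_versions` and `as_requirements` are hypothetical.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names, based on the imports used in the replication script
PACKAGES = ["numpy", "scipy", "torch", "scikit-learn", "matplotlib", "tqdm"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def as_requirements(versions):
    """Format the installed packages as pinned requirement lines."""
    return [f"{name}=={ver}" for name, ver in versions.items() if ver is not None]


print("# python", platform.python_version())
for line in as_requirements(installed_versions(PACKAGES)):
    print(line)
```

The printed lines can be pasted into an issue or saved as a provisional `requirements.txt`, together with the Python version from the first line.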
Hello,
Thank you for putting this together. I am also looking into benchmarking and reproducibility. I can see the ENOT implementation [1]; where are the other implementations, e.g. for FB-SB?
[1] Nikita Gushchin, Alexander Kolesov, Alexander Korotin, Dmitry Vetrov, and Evgeny Burnaev. Entropic neural optimal transport via diffusion processes. arXiv preprint arXiv:2211.01156, 2022.
Best,
James