diff --git a/RELEASES.md b/RELEASES.md index f0fae4fa0..470121d94 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -26,7 +26,10 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver a callable, or a no-op (PR #808) - Add optional `scaler` parameter to `sliced_wasserstein_distance` and `max_sliced_wasserstein_distance` (PR #808) - Add a numerically stable log-domain solver for entropic partial Wasserstein, selectable via the new `method` parameter of `entropic_partial_wasserstein` (`method='sinkhorn_log'`) or directly through `entropic_partial_wasserstein_logscale` (Issue #723) -- Add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920), implemented in `ot.sgot` (PR #792) +- Add cost functions between linear operators following + [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920), + implemented in `ot.sgot` (PR #792) +- Add batch FUGW loss to `ot.batch` and fix issues in some default parameters in the batch module (PR #775) - Build wheels on ubuntu ARM to avoid QEMU emulation (PR #818) #### Closed issues diff --git a/examples/backends/plot_gradient_descent.py b/examples/backends/plot_gradient_descent.py new file mode 100644 index 000000000..efe02cbce --- /dev/null +++ b/examples/backends/plot_gradient_descent.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- +r""" +=============================================================================== +Solve Fused Unbalanced Gromov Wasserstein with Adam +=============================================================================== + +Since the FUGW loss is differentiable, it can be minimized with first-order optimization. +We show how to do this with the `loss_fugw_batch` function and compare the results with +the dedicated FUGW solver `fused_unbalanced_gromov_wasserstein`. 
+""" + +# Author: Rémi Flamary +# Sonia Mazelet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pylab as pl +import torch +from time import perf_counter +import ot +from ot.batch._quadratic import loss_fugw_batch, tensor_batch +from ot.gromov import fused_unbalanced_gromov_wasserstein +from sklearn.manifold import MDS + + +# %% +# Generation of source and target graphs +# ---------------- + +rng = np.random.RandomState(42) + + +def get_sbm(n, nc, ratio, P): + nbpc = np.round(n * ratio).astype(int) + n = np.sum(nbpc) + C = np.zeros((n, n)) + for c1 in range(nc): + for c2 in range(c1 + 1): + if c1 == c2: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])): + for j in range(np.sum(nbpc[:c2]), i): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + else: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])): + for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[: c2 + 1])): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + + return C + C.T + + +def plot_graph(x, C, color="C0", s=100): + for j in range(C.shape[0]): + for i in range(j): + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k") + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k") + + +def get_sbm_labels(n, ratio): + nbpc = np.round(n * ratio).astype(int) + return np.concatenate( + [np.full(count, label, dtype=int) for label, count in enumerate(nbpc)] + ) + + +def get_noisy_one_hot(labels, n_classes, noise_level=0.1): + x = np.eye(n_classes)[labels] + x += noise_level * rng.randn(*x.shape) + return x + + +n1 = 15 +n2 = 10 +nc1 = 3 +nc2 = 2 +ratio1 = np.array([0.33, 0.33, 0.33]) +ratio2 = np.array([0.5, 0.5]) + +P1 = np.array([[0.8, 0.03, 0.0], [0.08, 0.8, 0.03], [0.0, 0.08, 0.8]]) +P2 = np.array(0.8 * np.eye(2) + 0.01 * np.ones((2, 2))) +C1 = get_sbm(n1, nc1, ratio1, P1) +C2 = get_sbm(n2, nc2, ratio2, P2) +labels1 = get_sbm_labels(n1, ratio1) +labels2 = get_sbm_labels(n2, ratio2) + +# Use noisy 
one-hot encodings of the SBM classes as node features. +feature_dim = max(nc1, nc2) +x1 = get_noisy_one_hot(labels1, feature_dim) +x2 = get_noisy_one_hot(labels2, feature_dim) +all_features = np.vstack([x1, x2]) +feature_min = all_features[:, :3].min(axis=0, keepdims=True) +feature_max = all_features[:, :3].max(axis=0, keepdims=True) + +# get 2d positions for visualization +pos1 = MDS(dissimilarity="precomputed", random_state=0, n_init=1).fit_transform(1 - C1) +pos2 = MDS(dissimilarity="precomputed", random_state=0, n_init=1).fit_transform(1 - C2) + +colors1 = np.clip( + (x1 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0 +) +colors2 = np.clip( + (x2 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0 +) + + +pl.figure(1, (10, 5)) +pl.clf() +pl.subplot(1, 2, 1) +plot_graph(pos1, C1, color=colors1) +pl.title("SBM source graph") +pl.axis("off") +pl.subplot(1, 2, 2) +plot_graph(pos2, C2, color=colors2) +pl.title("SBM target graph") +_ = pl.axis("off") + + +# %% +# Solve FUGW with Adam +# ---------------- + +# Even though `loss_fugw_batch` supports batches of problems, we use a +# batch of size 1 here for clarity. + +a = ot.unif(C1.shape[0]) +b = ot.unif(C2.shape[0]) +M = ot.dist(x1, x2) +M /= M.max() + +a_torch = torch.tensor(a[None, :]) +b_torch = torch.tensor(b[None, :]) +C1_torch = torch.tensor(C1[None, :, :]) +C2_torch = torch.tensor(C2[None, :, :]) +M_torch = torch.tensor(M[None, :, :]) +L = tensor_batch(a_torch, b_torch, C1_torch, C2_torch, loss="sqeuclidean") + +alpha_batch = 0.5 +# `loss_fugw_batch` uses alpha as: alpha * quadratic + (1 - alpha) * linear +# while the dedicated solver uses alpha as the coefficient of the linear term. 
+alpha_bcd = (1 - alpha_batch) / alpha_batch + +reg_marginals_batch = 0.5 +reg_marginals_bcd = reg_marginals_batch / alpha_batch +lr = 5e-2 +nb_iter_max = 1500 +tol = 1e-7 + +T0_torch = a_torch[:, :, None] * b_torch[:, None, :] +T_torch = torch.log(torch.expm1(T0_torch)).clone().requires_grad_(True) +optimizer = torch.optim.Adam([T_torch], lr=lr) +loss_iter = [] +mass_iter = [] +previous_plan_torch = None + +tic = perf_counter() +for i in range(nb_iter_max): + optimizer.zero_grad() + # Positive transport plan parameterized as log(1 + exp(T)). + plan_torch = torch.nn.functional.softplus(T_torch) + loss = loss_fugw_batch( + a_torch, + b_torch, + L, + M_torch, + plan_torch, + alpha=alpha_batch, + reg_marginals=reg_marginals_batch, + divergence="kl", + recompute_const=True, + )[0] + + loss_iter.append(float(loss.detach())) + mass_iter.append(float(plan_torch.detach().sum())) + if previous_plan_torch is not None: + err = float(torch.sum(torch.abs(plan_torch.detach() - previous_plan_torch))) + if err < tol: + break + previous_plan_torch = plan_torch.detach().clone() + loss.backward() + optimizer.step() +time_adam = perf_counter() - tic + +T_adam = torch.nn.functional.softplus(T_torch).detach().cpu().numpy()[0] + +pl.figure(2, (10, 4)) +pl.clf() +pl.subplot(1, 2, 1) +pl.plot(loss_iter) +pl.grid() +pl.title("FUGW loss along iterations") +pl.xlabel("Iterations") +pl.subplot(1, 2, 2) +pl.plot(mass_iter) +pl.grid() +pl.title("Transport mass") +_ = pl.xlabel("Iterations") + + +# %% +# Compare with the dedicated FUGW solver +# ------------------------------------- +# +# The dedicated solver uses a block coordinate descent (BCD) scheme. We compare +# the coupling it returns with the one obtained by direct Adam minimization of +# `loss_fugw_batch`. 
+ + +def evaluate_batch_fugw_loss(plan): + plan_torch = torch.tensor(plan[None, :, :], dtype=M_torch.dtype) + loss = loss_fugw_batch( + a_torch, + b_torch, + L, + M_torch, + plan_torch, + alpha=alpha_batch, + reg_marginals=reg_marginals_batch, + divergence="kl", + recompute_const=True, + )[0] + return float(loss.detach()) + + +tic = perf_counter() +T_bcd, _, log = fused_unbalanced_gromov_wasserstein( + C1, + C2, + wx=a, + wy=b, + reg_marginals=reg_marginals_bcd, + divergence="kl", + unbalanced_solver="mm", + alpha=alpha_bcd, + M=M, + init_pi=np.outer(a, b), + max_iter=200, + tol=tol, + max_iter_ot=200, + tol_ot=1e-7, + log=True, +) +time_bcd = perf_counter() - tic + +loss_adam_final = evaluate_batch_fugw_loss(T_adam) +loss_bcd_final = evaluate_batch_fugw_loss(T_bcd) + + +# %% +# Visualize the learned couplings +# ------------------------------- +# We visualize the couplings obtained by both methods to compare them. On this example, both methods recover similar couplings, +# but direct minimization reaches a lower `loss_fugw_batch` value at the cost +# of a longer runtime. 
+ +pl.figure(3, (10, 4)) +pl.clf() +pl.subplot(1, 2, 1) +pl.imshow(T_adam, interpolation="nearest") +pl.title( + f"Coupling from direct minimization\nloss={loss_adam_final:.3f}, time={time_adam:.2f}s" +) +pl.xlabel("Target nodes") +pl.ylabel("Source nodes") +pl.colorbar() +pl.subplot(1, 2, 2) +pl.imshow(T_bcd, interpolation="nearest") +pl.title(f"Coupling from BCD solver\nloss={loss_bcd_final:.3f}, time={time_bcd:.2f}s") +pl.xlabel("Target nodes") +pl.ylabel("Source nodes") +_ = pl.colorbar() diff --git a/ot/batch/_linear.py b/ot/batch/_linear.py index a63fcb404..1a9ec1955 100644 --- a/ot/batch/_linear.py +++ b/ot/batch/_linear.py @@ -147,7 +147,7 @@ def loss_linear_batch(M, T, nx=None): return nx.sum(M * T, axis=(1, 2)) -def loss_linear_samples_batch(X, Y, T, metric="l2"): +def loss_linear_samples_batch(X, Y, T, metric="sqeuclidean"): r"""Computes the linear optimal transport loss given samples and transport plan. This is the equivalent of calling `dist_batch` and then `loss_linear_batch`. diff --git a/ot/batch/_quadratic.py b/ot/batch/_quadratic.py index 0da4b8962..5b398fd80 100644 --- a/ot/batch/_quadratic.py +++ b/ot/batch/_quadratic.py @@ -10,8 +10,9 @@ from ..utils import OTResult from ot.backend import get_backend -from ot.batch._linear import loss_linear_batch +from ot.batch._linear import loss_linear_batch, loss_linear_samples_batch from ot.batch._utils import bmv, bop, bregman_log_projection_batch +from ot.utils import list_to_array def tensor_batch( @@ -152,6 +153,85 @@ def h2(C2): return compute_tensor_batch(f1, f2, h1, h2, a, b, C1, C2, symmetric=symmetric) +def div_between_product_batch(mu, nu, alpha, beta, divergence, nx=None): + r"""Fast computation of the Bregman divergence between batches of product measures. + Only the Kullback-Leibler and half-squared L2 divergences are supported. + + For half-squared L2 divergence: + + .. 
math:: + \frac{1}{2} || \mu \otimes \nu - \alpha \otimes \beta ||^2 + = \frac{1}{2} \Big[ ||\alpha||^2 ||\beta||^2 + ||\mu||^2 ||\nu||^2 - 2 \langle \alpha, \mu \rangle \langle \beta, \nu \rangle \Big] + + For Kullback-Leibler divergence: + + .. math:: + KL(\mu \otimes \nu, \alpha \otimes \beta) + = m(\mu) * KL(\nu, \beta) + m(\nu) * KL(\mu, \alpha) + (m(\mu) - m(\alpha)) * (m(\nu) - m(\beta)) + + where: + + - :math:`\mu` and :math:`\alpha` are two measures having the same shape. + - :math:`\nu` and :math:`\beta` are two measures having the same shape. + - :math:`m` denotes the mass of the measure + + Parameters + ---------- + mu : array-like, shape (B, ...) + First factor of each product measure in the batch. + nu : array-like, shape (B, ...) + Second factor of each product measure in the batch. + alpha : array-like, shape (B, ...) + Reference factor with the same shape as `mu`. + beta : array-like, shape (B, ...) + Reference factor with the same shape as `nu`. + divergence : string + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + nx : backend, optional + If left to its default value None, the backend is inferred from the inputs. + + Returns + ---------- + Bregman divergence between two product measures for each problem in the batch. + + Raises + ------ + ValueError + If `divergence` is neither "kl" nor "l2". 
+ """ + + if nx is None: + nx = get_backend(mu, nu, alpha, beta) + + axis_mu = tuple(range(1, mu.ndim)) if mu.ndim > 1 else 0 + axis_nu = tuple(range(1, nu.ndim)) if nu.ndim > 1 else 0 + axis_alpha = tuple(range(1, alpha.ndim)) if alpha.ndim > 1 else 0 + axis_beta = tuple(range(1, beta.ndim)) if beta.ndim > 1 else 0 + + if divergence == "kl": + m_mu = nx.sum(mu, axis=axis_mu) + m_nu = nx.sum(nu, axis=axis_nu) + m_alpha = nx.sum(alpha, axis=axis_alpha) + m_beta = nx.sum(beta, axis=axis_beta) + const = (m_mu - m_alpha) * (m_nu - m_beta) + res = ( + m_nu * nx.kl_div(mu, alpha, mass=True, axis=axis_mu) + + m_mu * nx.kl_div(nu, beta, mass=True, axis=axis_nu) + + const + ) + + elif divergence == "l2": + res = ( + nx.sum(alpha**2, axis=axis_alpha) * nx.sum(beta**2, axis=axis_beta) + - 2 * nx.sum(alpha * mu, axis=axis_mu) * nx.sum(beta * nu, axis=axis_nu) + + nx.sum(mu**2, axis=axis_mu) * nx.sum(nu**2, axis=axis_nu) + ) / 2 + else: + raise ValueError(f"Unknown divergence {divergence!r}, expected 'kl' or 'l2'") + + return res + + def loss_quadratic_batch(L, T, recompute_const=False, symmetric=True, nx=None): r""" Computes the gromov-wasserstein cost given a cost tensor and transport plan. Batched version. @@ -205,7 +285,7 @@ def loss_quadratic_samples_batch( C2, T, loss="sqeuclidean", - symmetric=None, + symmetric=True, nx=None, logits=None, recompute_const=False, @@ -266,6 +346,206 @@ def loss_quadratic_samples_batch( ) +def loss_fugw_batch( + a, + b, + L, + M, + T, + alpha=0.5, + reg_marginals=1, + symmetric=True, + divergence="kl", + recompute_const=True, + nx=None, +): + r""" + Computes the fused unbalanced gromov-wasserstein cost given a cost tensor (Gromov term), a cost matrix between features across domains (linear term) and a transport plan. Batched version. + + Parameters + ---------- + a : array-like, shape (B, n) + Source distributions. + b : array-like, shape (B, m) + Target distributions. + L : dict + Cost tensor as returned by `tensor_batch`. 
+ M : array-like, shape (B, n, m) + Cost matrix between features across domains. 
+ T : array-like, shape (B, n, m) + Transport plan. + alpha : float, array-like or list (B,), optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha is + a scalar it is used for all problems in the batch. + reg_marginals : float, array-like or list (B,), optional + Marginal relaxation terms. If reg_marginals is + a scalar it is used for all problems in the batch. + symmetric : bool, optional + Whether to use symmetric version. Default is True. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + recompute_const : bool, optional + Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints. + nx : module, optional + Backend to use. Default is None. + """ + if nx is None: + nx = get_backend(T) + + B = T.shape[0] + + if isinstance(alpha, list): + alpha = list_to_array(alpha, nx=nx) + + if isinstance(reg_marginals, list): + reg_marginals = list_to_array(reg_marginals, nx=nx) + + if hasattr(alpha, "ndim") and alpha.ndim > 0: + if alpha.ndim != 1 or alpha.shape[0] != B: + raise ValueError( + f"If alpha is not a scalar, it must have shape ({B},), got {alpha.shape}" + ) + + if hasattr(reg_marginals, "ndim") and reg_marginals.ndim > 0: + if reg_marginals.ndim != 1 or reg_marginals.shape[0] != B: + raise ValueError( + f"If reg_marginals is not a scalar, it must have shape ({B},), got {reg_marginals.shape}" + ) + + quadratic = loss_quadratic_batch( + L, T, recompute_const=recompute_const, symmetric=symmetric, nx=nx + ) + + linear = loss_linear_batch(M, T, nx=nx) + + T1 = nx.sum(T, 2) + T2 = nx.sum(T, 1) + unbalanced = div_between_product_batch( + T1, + T2, + a, + b, + divergence=divergence, + nx=nx, + ) + + return (1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced + + +def loss_fugw_samples_batch( + a, + b, + C1, + C2, + X, + Y, + T, 
+ alpha=0.5, + reg_marginals=1, + symmetric=True, + divergence="kl", + recompute_const=True, + metric_linear="sqeuclidean", + metric_quadratic="sqeuclidean", + logits=None, + nx=None, +): + r""" + Computes the fused unbalanced gromov-wasserstein cost given source and target cost matrices (quadratic term), samples from both domains (linear term) and a transport plan. Batched version. + + Parameters + ---------- + a : array-like, shape (B, n) + Source distributions. + b : array-like, shape (B, m) + Target distributions. + C1 : array-like, shape (B, n, n) or (B, n, n, d) + Source cost matrices for the quadratic term. + C2 : array-like, shape (B, m, m) or (B, m, m, d) + Target cost matrices for the quadratic term. + X : array-like, shape (B, n, d) + Samples from source distribution for the linear term + Y : array-like, shape (B, m, d) + Samples from target distribution for the linear term + T : array-like, shape (B, n, m) + Transport plan. + alpha : float or array-like or list (B,), optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha is + a scalar it is used for all problems in the batch. + reg_marginals : float or array-like or list (B,), optional + Marginal relaxation terms. If reg_marginals is + a scalar it is used for all problems in the batch. + symmetric : bool, optional + Whether to use symmetric version. Default is True. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + recompute_const : bool, optional + Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints. + metric_linear : str, optional + Metric for the linear term, 'sqeuclidean', 'euclidean', 'minkowski' or 'kl' + metric_quadratic : str, optional + Metric to use for the quadratic term. Supported values: 'sqeuclidean', 'kl'. + Default is 'sqeuclidean'. 
+ logits : bool, optional + For KL divergence, whether inputs are logits (unnormalized log probabilities). + If True, inputs are treated as logits. Default is None. + nx : module, optional + Backend to use. Default is None. + """ + if nx is None: + nx = get_backend(T) + + B = T.shape[0] + + if isinstance(alpha, list): + alpha = list_to_array(alpha, nx=nx) + + if isinstance(reg_marginals, list): + reg_marginals = list_to_array(reg_marginals, nx=nx) + + if hasattr(alpha, "ndim") and alpha.ndim > 0: + if alpha.ndim != 1 or alpha.shape[0] != B: + raise ValueError( + f"If alpha is not a scalar, it must have shape ({B},), got {alpha.shape}" + ) + + if hasattr(reg_marginals, "ndim") and reg_marginals.ndim > 0: + if reg_marginals.ndim != 1 or reg_marginals.shape[0] != B: + raise ValueError( + f"If reg_marginals is not a scalar, it must have shape ({B},), got {reg_marginals.shape}" + ) + + quadratic = loss_quadratic_samples_batch( + a, + b, + C1, + C2, + T, + loss=metric_quadratic, + symmetric=symmetric, + nx=nx, + logits=logits, + recompute_const=recompute_const, + ) + + linear = loss_linear_samples_batch(X, Y, T, metric=metric_linear) + + T1 = nx.sum(T, 2) + T2 = nx.sum(T, 1) + unbalanced = div_between_product_batch( + T1, + T2, + a, + b, + divergence=divergence, + nx=nx, + ) + + return (1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced + + def solve_gromov_batch( C1, C2, diff --git a/test/batch/test_solve_batch.py b/test/batch/test_solve_batch.py index 45a7e69fe..17d459a43 100644 --- a/test/batch/test_solve_batch.py +++ b/test/batch/test_solve_batch.py @@ -1,9 +1,8 @@ -"""Tests for module bregman on OT with bregman projections""" +"""Tests for module batch""" # Author: Remi Flamary -# Kilian Fatras -# Quang Huy Tran -# Eduardo Fernandes Montesuma +# Paul Krzakala +# Sonia Mazelet # # License: MIT License @@ -143,3 +142,28 @@ def test_backend(nx): M = dist_batch(X, X) solve_batch(M, reg=0.1, max_iter=10, tol=1e-5) solve_sample_batch(X, X, reg=0.1, 
max_iter=10, tol=1e-5) + + +def test_metric_default_parameters(): + """Check that all functions with default parameters run without error.""" + + batchsize = 2 + n = 4 + d = 2 + rng = np.random.RandomState(0) + X = rng.rand(batchsize, n, d) + M = dist_batch(X, X) + is_positive = M >= 0 + np.testing.assert_equal(is_positive.all(), True) + + # Solve batch + res = solve_batch(M, reg=0.1, max_iter=10, tol=1e-5) + + # Solve sample batch + res = solve_sample_batch(X, X, reg=0.1) + + # Compute loss + loss_linear_batch(M, res.plan) # recompute loss from plan + loss_linear_samples_batch(X, X, res.plan) # recompute loss from plan and samples + assert np.isfinite(loss_linear_batch(M, res.plan)).all() + assert np.isfinite(loss_linear_samples_batch(X, X, res.plan)).all() diff --git a/test/batch/test_solve_gromov_batch.py b/test/batch/test_solve_gromov_batch.py index e0029689b..cabb9219b 100644 --- a/test/batch/test_solve_gromov_batch.py +++ b/test/batch/test_solve_gromov_batch.py @@ -1,19 +1,32 @@ -"""Tests for module bregman on OT with bregman projections""" +"""Tests for module batch""" # Author: Remi Flamary -# Kilian Fatras -# Quang Huy Tran -# Eduardo Fernandes Montesuma +# Paul Krzakala +# Sonia Mazelet + + # # License: MIT License import numpy as np -from ot.batch import solve_gromov_batch, loss_quadratic_samples_batch +from ot.batch import ( + solve_gromov_batch, + loss_quadratic_batch, + loss_linear_batch, + loss_quadratic_samples_batch, +) from ot import solve_gromov from ot.batch._linear import dist_batch import pytest from itertools import product from ot.backend import torch +from ot.batch._quadratic import ( + tensor_batch, + loss_fugw_batch, + loss_fugw_samples_batch, + div_between_product_batch, +) +from ot.gromov._utils import div_between_product def test_solve_gromov_batch(): @@ -133,3 +146,199 @@ def test_backend(nx): C = np.random.randn(batchsize, n, n, d) C = nx.from_numpy(C) solve_gromov_batch(C1=C, C2=C, a=None, b=None, loss="sqeuclidean", logits=False) + 
+ +@pytest.mark.parametrize("divergence", ["kl", "l2"]) +@pytest.mark.parametrize( + "metric_linear", ["sqeuclidean", "euclidean", "minkowski", "kl"] +) +@pytest.mark.parametrize("metric_quadratic", ["sqeuclidean", "kl"]) +def test_fugw_loss(divergence, metric_linear, metric_quadratic): + """Check that loss_fugw_batch and loss_fugw_samples_batch run without error.""" + batchsize = 2 + n = 4 + d = 2 + rng = np.random.RandomState(0) + C1 = rng.rand(batchsize, n, n, d) + C2 = rng.rand(batchsize, n, n, d) + X = rng.rand(batchsize, n, d) + Y = rng.rand(batchsize, n, d) + M = rng.rand(batchsize, n, n) + a = np.ones((batchsize, n)) + reg_marginals = 0 + T = rng.rand(batchsize, n, n) + L = tensor_batch(a=a, b=a, C1=C1, C2=C2, loss="sqeuclidean") + alpha = rng.rand() + reg_marginals = rng.rand() + logits = False if metric_quadratic == "kl" else None + + loss_fugw = loss_fugw_batch( + a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals, divergence=divergence + ) + loss_fugw_sample = loss_fugw_samples_batch( + a, + a, + C1, + C2, + X, + Y, + T, + alpha=alpha, + reg_marginals=reg_marginals, + divergence=divergence, + metric_linear=metric_linear, + metric_quadratic=metric_quadratic, + logits=logits, + ) + assert np.isfinite(loss_fugw).all() + assert np.isfinite(loss_fugw_sample).all() + + # check that alpha and reg_marginals can be passed as lists or arrays of shape (batchsize,) + alpha = rng.rand(batchsize) + reg_marginals = rng.rand(batchsize) + alpha_list = alpha.tolist() + reg_marginals_list = reg_marginals.tolist() + + loss_fugw = loss_fugw_batch( + a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals, divergence=divergence + ) + loss_fugw_sample = loss_fugw_samples_batch( + a, + a, + C1, + C2, + X, + Y, + T, + alpha=alpha, + reg_marginals=reg_marginals, + divergence=divergence, + metric_linear=metric_linear, + metric_quadratic=metric_quadratic, + logits=logits, + ) + loss_fugw_list = loss_fugw_batch( + a, + a, + L, + M, + T, + alpha=alpha_list, + 
reg_marginals=reg_marginals_list, + divergence=divergence, + ) + loss_fugw_sample_list = loss_fugw_samples_batch( + a, + a, + C1, + C2, + X, + Y, + T, + alpha=alpha_list, + reg_marginals=reg_marginals_list, + divergence=divergence, + metric_linear=metric_linear, + metric_quadratic=metric_quadratic, + logits=logits, + ) + + assert np.isfinite(loss_fugw).all() + assert np.isfinite(loss_fugw_sample).all() + assert np.isfinite(loss_fugw_list).all() + assert np.isfinite(loss_fugw_sample_list).all() + np.testing.assert_allclose(loss_fugw, loss_fugw_list) + np.testing.assert_allclose(loss_fugw_sample, loss_fugw_sample_list) + + # check that invalid alpha shape raise an error + alpha = rng.rand(batchsize + 1) + with pytest.raises(ValueError): + loss_fugw_batch(a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals) + with pytest.raises(ValueError): + loss_fugw_samples_batch( + a, + a, + C1, + C2, + X, + Y, + T, + alpha=alpha, + reg_marginals=reg_marginals, + divergence=divergence, + metric_linear=metric_linear, + metric_quadratic=metric_quadratic, + logits=logits, + ) + + # check that invalid rho shape raise an error + alpha = rng.rand(batchsize) + reg_marginals = rng.rand(batchsize + 1) + with pytest.raises(ValueError): + loss_fugw_batch(a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals) + with pytest.raises(ValueError): + loss_fugw_samples_batch( + a, + a, + C1, + C2, + X, + Y, + T, + alpha=alpha, + reg_marginals=reg_marginals, + divergence=divergence, + metric_linear=metric_linear, + metric_quadratic=metric_quadratic, + logits=logits, + ) + + +def test_valid_fugw_loss_endpoints(): + """Check that loss_fugw_batch gives the same results as solve_gromov_batch and solve_linear_batch for alpha=0 and alpha=1.""" + batchsize = 2 + n = 4 + d = 2 + rng = np.random.RandomState(0) + C1 = rng.rand(batchsize, n, n, d) + C2 = rng.rand(batchsize, n, n, d) + M = rng.rand(batchsize, n, n) + a = np.ones((batchsize, n)) + reg_marginals = 0 + T = rng.rand(batchsize, n, n) + L = 
tensor_batch(a=a, b=a, C1=C1, C2=C2, loss="sqeuclidean") + + loss_fugw = loss_fugw_batch( + a, a, L, M, T, alpha=0.0, divergence="l2", reg_marginals=reg_marginals + ) + loss_linear = loss_linear_batch(M, T) + np.testing.assert_allclose(loss_fugw, loss_linear, atol=1e-5) + + loss_fugw = loss_fugw_batch( + a, a, L, M, T, alpha=1.0, divergence="l2", reg_marginals=reg_marginals + ) + loss_gromov = loss_quadratic_batch(L, T, recompute_const=True) + np.testing.assert_allclose(loss_fugw, loss_gromov, atol=1e-5) + + +@pytest.mark.parametrize("divergence", ["kl", "l2"]) +def test_div_between_product(divergence): + batchsize = 2 + n = 4 + m = 3 + rng = np.random.RandomState(0) + mu = rng.rand(batchsize, n) + nu = rng.rand(batchsize, m) + alpha = rng.rand(batchsize, n) + beta = rng.rand(batchsize, m) + + res_batch = div_between_product_batch( + mu, nu, alpha, beta, divergence=divergence, nx=None + ) + res = np.array( + [ + div_between_product(mu[i], nu[i], alpha[i], beta[i], divergence) + for i in range(batchsize) + ] + ) + np.testing.assert_allclose(res_batch, res, atol=1e-5)