Quadratic problems (advanced)

This example shows advanced usage of quadratic problems, e.g., the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, the GWProblem, and the FGWProblem.

Imports and data loading

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import GWProblem

import scanpy as sc

Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
sc.pp.pca(adata)
adata
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
)
gwp
GWProblem[('0', '1')]

Threshold

The threshold parameter defines the convergence criterion. In the balanced setting, the threshold bounds the deviation between the prior and posterior marginals. In the unbalanced setting, it corresponds to a Cauchy sequence stopping criterion, i.e., it bounds the change between successive iterates.
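To make the two kinds of criteria concrete, here is a minimal NumPy sketch. The helpers marginal_deviation and cauchy_criterion are hypothetical, and the norms they use (L1 for the marginal deviation, maximum absolute change for the Cauchy criterion) are assumptions chosen for illustration; ott's actual error computation may differ.

```python
import numpy as np


def marginal_deviation(coupling, a, b):
    # Balanced setting (assumed L1 norm): how far the coupling's row and
    # column sums are from the prescribed marginals a and b.
    row_error = np.abs(coupling.sum(axis=1) - a).sum()
    col_error = np.abs(coupling.sum(axis=0) - b).sum()
    return row_error + col_error


def cauchy_criterion(previous, current):
    # Unbalanced setting (assumed max norm): how much an iterate, e.g. a
    # dual potential, changed between two successive iterations.
    return np.abs(current - previous).max()


# Toy 2x2 coupling: rows sum to 0.5 and 0.5, columns to 0.45 and 0.55.
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
coupling = np.array([[0.30, 0.20], [0.15, 0.35]])

threshold = 0.2
# The solver would stop once the error drops below the threshold.
converged = marginal_deviation(coupling, a, b) < threshold
```

A smaller threshold makes the stopping rule stricter, which typically costs more iterations.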

Initializers

Different initializers can help improve convergence. For the full-rank case, only the default initializer exists, hence the initializer argument must be set to None.

For low-rank problems, the same initializers as for the linear low-rank solvers are available, and initializer_kwargs can be passed the same way; see Linear problems (advanced) for more information.

Number of iterations

To solve a quadratic optimal transport problem, a sequence of linearized problems is solved, each linearized around the solution of the previous one. Here, min_iterations and max_iterations denote a lower and an upper bound on the number of these outer iterations. If max_iterations is too low, the solver might not converge, as in the following call, which allows only a single outer iteration.
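The outer loop can be sketched in NumPy as entropic Gromov-Wasserstein with the square loss, following the linearize-then-Sinkhorn scheme of Peyré, Cuturi, and Solomon (2016). This is not moscot's or ott's implementation: the function names are hypothetical, the product-coupling start is an assumption, and the stopping rule (change in the coupling below threshold) is chosen for illustration only.

```python
import numpy as np


def _logsumexp(x, axis):
    # Numerically stable log-sum-exp along one axis.
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn(a, b, cost, epsilon, n_iter=500):
    # Inner linear solver: log-domain Sinkhorn with a fixed iteration budget.
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def linearized_cost(C1, C2, a, b, T):
    # Square loss: cost[i, j] = sum_kl (C1[i, k] - C2[j, l])**2 * T[k, l]
    const = ((C1**2) @ a)[:, None] + ((C2**2) @ b)[None, :]
    return const - 2.0 * C1 @ T @ C2.T


def entropic_gw(C1, C2, a, b, epsilon=0.1, threshold=1e-3,
                min_iterations=0, max_iterations=20):
    T = np.outer(a, b)  # assumption: start from the product coupling
    converged, n_outer = False, 0
    for n_outer in range(1, max_iterations + 1):
        # Linearize the quadratic objective around T, then solve a linear OT problem.
        T_new = sinkhorn(a, b, linearized_cost(C1, C2, a, b, T), epsilon)
        delta = np.abs(T_new - T).sum()
        T = T_new
        # Never stop before min_iterations outer steps.
        if n_outer >= min_iterations and delta < threshold:
            converged = True
            break
    return T, n_outer, converged


# Toy "spatial" coordinates for two batches of 20 cells each.
rng = np.random.default_rng(0)
x, y = rng.random((20, 2)), rng.random((20, 2))
C1 = np.linalg.norm(x[:, None] - x[None, :], axis=-1)
C2 = np.linalg.norm(y[:, None] - y[None, :], axis=-1)
C1, C2 = C1 / C1.max(), C2 / C2.max()
a = np.full(20, 1 / 20)
b = np.full(20, 1 / 20)

# Mirrors max_iterations=1 above: exactly one outer step is taken.
T1, n1, conv1 = entropic_gw(C1, C2, a, b, epsilon=0.1, max_iterations=1)
# A larger budget with a lower bound of two outer steps.
T2, n2, conv2 = entropic_gw(C1, C2, a, b, epsilon=0.1, threshold=0.1,
                            min_iterations=2, max_iterations=20)
```

Each outer step solves one full linear problem, so the total cost grows with max_iterations times the inner solver's work.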

gwp = gwp.solve(epsilon=1e-1, min_iterations=0, max_iterations=1)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING  Solver did not converge                                                                                   

Linear solver keyword arguments

As mentioned above, each outer loop step of the Gromov-Wasserstein algorithm consists of solving a linear problem. Arguments for the linear solver can be specified via linear_solver_kwargs; these are passed as keyword arguments to Sinkhorn in the full-rank case and to LRSinkhorn in the low-rank case. This way, we can also set the minimum and maximum number of iterations for the linear solver:
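The following NumPy sketch illustrates how the three keys in a dictionary like ls_kwargs could act on an inner Sinkhorn loop. The sinkhorn function here is hypothetical and much simpler than ott's Sinkhorn, which has more options and may check convergence differently; the row-marginal L1 error used as the stopping quantity is an assumption.

```python
import numpy as np


def _logsumexp(x, axis):
    # Numerically stable log-sum-exp along one axis.
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn(a, b, cost, epsilon, min_iterations=0, max_iterations=1000, threshold=1e-3):
    # Log-domain Sinkhorn that honors min/max iterations and a threshold.
    f, g = np.zeros_like(a), np.zeros_like(b)
    coupling = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    converged, n_iter = False, 0
    for n_iter in range(1, max_iterations + 1):
        f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
        coupling = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
        # After the g-update the column sums equal b, so only the rows are checked.
        error = np.abs(coupling.sum(axis=1) - a).sum()
        if n_iter >= min_iterations and error < threshold:
            converged = True
            break
    return coupling, n_iter, converged


rng = np.random.default_rng(0)
x, y = rng.random((20, 2)), rng.random((20, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean
a = np.full(20, 1 / 20)
b = np.full(20, 1 / 20)

# Same keys as ls_kwargs in the call below.
ls_kwargs = {"min_iterations": 10, "max_iterations": 1000, "threshold": 0.01}
coupling, n_iter, converged = sinkhorn(a, b, cost, epsilon=0.1, **ls_kwargs)
```

min_iterations guarantees a minimum amount of work even if the threshold is met early, while max_iterations caps the work per outer step when it is not.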

ls_kwargs = {"min_iterations": 10, "max_iterations": 1000, "threshold": 0.01}
gwp = gwp.solve(
    epsilon=1e-1,
    threshold=0.1,
    min_iterations=2,
    max_iterations=20,
    linear_solver_kwargs=ls_kwargs,
)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                

Low-rank hyperparameters

The parameters gamma and gamma_rescale are the same as in the linear case; see the example Linear problems (advanced).

Keyword arguments and implementation details

Whenever the solve method of a quadratic problem is called, a backend-specific quadratic solver is instantiated. Currently, the ott backend is supported; its corresponding quadratic solver is GromovWasserstein, which handles both the full-rank and the low-rank case. moscot wraps this class in GWSolver, which handles both the purely quadratic and the fused quadratic problem.