Quadratic problems

This example shows how to solve quadratic problems, e.g., the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, the GWProblem, and the FGWProblem.


Imports and data loading

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem, GWProblem

import numpy as np

import scanpy as sc

Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
sc.pp.pca(adata)
adata
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'

Basic parameters

Several parameters of quadratic problems play the same role as in linear problems; see Linear problems for the role of epsilon, tau_a, and tau_b. Fused quadratic problems (also referred to as Fused Gromov-Wasserstein) have an additional parameter, alpha, which defines the convex combination of the quadratic term and the linear term, where the linear term is defined by joint_attr. Setting alpha = 1 yields the pure quadratic problem and ignores joint_attr. Setting alpha = 0 is not supported; for a purely linear objective, use a linear problem instead.
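To make the role of alpha more concrete, the NumPy sketch below evaluates a fused objective of the form alpha * (quadratic term) + (1 - alpha) * (linear term) for a fixed coupling, using the squared loss for the quadratic term and toy data in place of the spatial coordinates and X_pca. It also contains a hypothetical helper, fused_penalty, which assumes the objective is divided by alpha so that the quadratic term keeps unit weight and the linear term is weighted by (1 - alpha) / alpha. Under that assumption alpha = 0 has no valid weight, which is one way to read why it is rejected. None of these functions are part of the moscot API.

```python
import numpy as np


def fused_penalty(alpha: float) -> float:
    """Hypothetical helper (assumption): linear-term weight after dividing the objective by alpha."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"alpha must lie in (0, 1], got {alpha}")
    return (1.0 - alpha) / alpha


def quadratic_term(cx: np.ndarray, cy: np.ndarray, t: np.ndarray) -> float:
    """Squared-loss GW term: sum over i, j, k, l of (cx[i, k] - cy[j, l])**2 * t[i, j] * t[k, l]."""
    diff = cx[:, None, :, None] - cy[None, :, None, :]  # axes ordered as [i, j, k, l]
    return float(np.einsum("ijkl,ij,kl->", diff**2, t, t))


def linear_term(m: np.ndarray, t: np.ndarray) -> float:
    """Linear term <M, T>, where M compares the shared (joint_attr-like) features."""
    return float(np.sum(m * t))


def fgw_objective(cx, cy, m, t, alpha: float) -> float:
    """Convex combination of the quadratic and the linear term."""
    return alpha * quadratic_term(cx, cy, t) + (1.0 - alpha) * linear_term(m, t)


def sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between the rows of a and b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)


# Toy data: 2-D "spatial" coordinates per batch and 5-D features shared by both batches.
rng = np.random.default_rng(0)
xs, ys = rng.normal(size=(4, 2)), rng.normal(size=(5, 2))
xf, yf = rng.normal(size=(4, 5)), rng.normal(size=(5, 5))
cx, cy, m = sq_dists(xs, xs), sq_dists(ys, ys), sq_dists(xf, yf)
t = np.full((4, 5), 1.0 / 20)  # independent coupling of two uniform marginals

q, lin = quadratic_term(cx, cy, t), linear_term(m, t)
# alpha = 1 ignores the linear term entirely.
print(fgw_objective(cx, cy, m, t, alpha=1.0) == q)
# Dividing by alpha gives "quadratic + fused_penalty(alpha) * linear".
print(np.isclose(fgw_objective(cx, cy, m, t, 0.5) / 0.5, q + fused_penalty(0.5) * lin))
```

Both checks hold for any alpha in (0, 1]; only alpha = 0 breaks the rescaled form.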

# Pure quadratic (Gromov-Wasserstein) problem: uses only the spatial
# coordinates of each batch.
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
)
gwp = gwp.solve(epsilon=1e-1)

# Fused problem: the same quadratic term plus a linear term on the shared PCA space.
fgwp = FGWProblem(adata)
fgwp = fgwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
    joint_attr="X_pca",
)
fgwp = fgwp.solve(epsilon=1e-1, alpha=0.5)

# Compare the two couplings between batch "0" and batch "1".
max_difference = np.max(
    np.abs(
        gwp["0", "1"].solution.transport_matrix
        - fgwp["0", "1"].solution.transport_matrix
    )
)
print(f"max difference: {max_difference:.6f}")
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              
max difference: 0.021854
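For intuition about what solving a quadratic problem with a given epsilon involves, here is a minimal NumPy sketch of a classic entropic Gromov-Wasserstein scheme for the squared loss: linearize the quadratic term around the current coupling, then re-solve an entropic linear OT problem with Sinkhorn, and repeat. This is an illustration under stated assumptions, not moscot's solver: the function names, iteration counts, and the exact scaling of eps are hypothetical, so the resulting coupling will not match the transport matrices above.

```python
import numpy as np


def sinkhorn(cost, p, q, eps, n_iter=1000):
    """Entropic linear OT via Sinkhorn scaling; rows match p exactly, columns approximately."""
    # Shifting the cost by a constant does not change the optimal coupling,
    # but keeps the Gibbs kernel away from underflow.
    k = np.exp(-(cost - cost.min()) / eps)
    u = np.ones_like(p)
    for _ in range(n_iter):
        v = q / (k.T @ u)
        u = p / (k @ v)
    return u[:, None] * k * v[None, :]


def entropic_gw(cx, cy, p, q, eps, n_outer=30):
    """Hypothetical entropic GW (squared loss): repeated linearization + Sinkhorn; cx, cy symmetric."""
    t = np.outer(p, q)  # start from the independent coupling
    # Part of the linearized cost that depends only on the marginals p and q.
    const = (cx**2 @ p)[:, None] + (cy**2 @ q)[None, :]
    for _ in range(n_outer):
        # cost[i, j] = sum_{k, l} (cx[i, k] - cy[j, l])**2 * t[k, l], if t has marginals (p, q)
        cost = const - 2.0 * cx @ t @ cy.T
        t = sinkhorn(cost, p, q, eps)
    return t


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 2))
y = x @ np.array([[0.0, -1.0], [1.0, 0.0]])  # rotated copy: identical intra-domain distances
cx = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
cy = ((y[:, None, :] - y[None, :, :]) ** 2).sum(-1)
cx, cy = cx / cx.max(), cy / cy.max()  # keep costs on a scale comparable to eps
p = q = np.full(20, 1 / 20)

t = entropic_gw(cx, cy, p, q, eps=1e-1)
print(t.shape, round(float(t.sum()), 6))
```

The rotated copy illustrates why only intra-domain distances enter the quadratic term: the raw coordinates of x and y differ, yet cx and cy are identical.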

Low-rank solutions

For large datasets, the computational complexity can be reduced by setting rank to a positive integer [Scetbon et al., 2021]. In this case, epsilon can also be set to 0, but only the balanced case (tau_a = tau_b = 1) is supported. Moreover, the data has to be provided as point clouds, i.e., a precomputed cost matrix cannot be passed.
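A low-rank coupling is stored in factored form instead of as a dense n x m matrix. The sketch below assumes the factorization from low-rank OT [Scetbon et al., 2021], T = Q diag(1/g) R^T with Q 1 = a, R 1 = b and Q^T 1 = R^T 1 = g, and builds a random feasible instance with NumPy. The helper scale_to_marginals is hypothetical (a plain iterative proportional fitting loop). The sketch shows that both marginals a and b are matched exactly, as in the balanced setting, and that storage drops from n * m to (n + m + 1) * r numbers.

```python
import numpy as np


def scale_to_marginals(mat, rows, cols, n_iter=500):
    """Hypothetical helper: fit a positive matrix to given row/column sums (iterative proportional fitting)."""
    mat = mat.copy()
    for _ in range(n_iter):
        mat *= (rows / mat.sum(axis=1))[:, None]
        mat *= (cols / mat.sum(axis=0))[None, :]
    return mat


rng = np.random.default_rng(0)
n, m, r = 50, 40, 3
a, b = np.full(n, 1 / n), np.full(m, 1 / m)  # balanced: both marginals have total mass 1
g = np.full(r, 1 / r)  # inner marginal shared by both factors

q_fac = scale_to_marginals(rng.random((n, r)) + 0.1, a, g)  # Q 1 = a, Q^T 1 = g
r_fac = scale_to_marginals(rng.random((m, r)) + 0.1, b, g)  # R 1 = b, R^T 1 = g

# T = Q diag(1/g) R^T; formed densely here only to check its properties.
t = q_fac @ np.diag(1 / g) @ r_fac.T
print("rank:", np.linalg.matrix_rank(t))
print("dense entries:", t.size, "factored entries:", q_fac.size + r_fac.size + g.size)
```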

gwp = gwp.solve(epsilon=1e-2, rank=3)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
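One plausible reason for requiring point clouds rather than a precomputed cost matrix: for the squared Euclidean cost, C[i, j] = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j> factors as C = A B^T with only d + 2 columns, so products with C can be computed from the coordinates in O((n + m) d) time without forming the n x m matrix. The NumPy sketch below checks this identity; the function names are hypothetical and this is not moscot's internal code.

```python
import numpy as np


def dense_sq_euclidean(x, y):
    """Explicit n x m cost matrix, the object a low-rank solver tries to avoid."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)


def factored_sq_euclidean(x, y):
    """Hypothetical helper: factors A (n x (d+2)) and B (m x (d+2)) with C = A @ B.T."""
    n, m = len(x), len(y)
    a = np.hstack([(x**2).sum(1, keepdims=True), np.ones((n, 1)), -2.0 * x])
    b = np.hstack([np.ones((m, 1)), (y**2).sum(1, keepdims=True), y])
    return a, b


def cost_matvec(a, b, v):
    """Compute C @ v as A @ (B.T @ v) without materializing C."""
    return a @ (b.T @ v)


rng = np.random.default_rng(0)
x, y = rng.normal(size=(30, 2)), rng.normal(size=(25, 2))  # 2-D, like obsm["spatial"]
a, b = factored_sq_euclidean(x, y)
v = rng.random(25)

print(np.allclose(a @ b.T, dense_sq_euclidean(x, y)))
print(np.allclose(cost_matvec(a, b, v), dense_sq_euclidean(x, y) @ v))
```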

Scaling the cost

The scale_cost parameter works the same way as for linear problems; see Linear problems for more information. Note that all cost terms are scaled using the same scale_cost argument.
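As a rough illustration of scaling several cost terms with the same argument, the sketch below applies one rule to each cost matrix separately, so each term (source-space, target-space, and joint) is normalized by its own statistic. The helper apply_scale_cost and the small set of values it accepts ("mean", "median", "max_cost", or a float divisor) are assumptions modeled on common options; the values moscot actually supports are listed in its API reference.

```python
import numpy as np


def apply_scale_cost(cost, scale_cost):
    """Hypothetical helper: divide a cost matrix by a statistic chosen by `scale_cost`."""
    if isinstance(scale_cost, (int, float)):
        factor = float(scale_cost)
    elif scale_cost == "mean":
        factor = float(np.mean(cost))
    elif scale_cost == "median":
        factor = float(np.median(cost))
    elif scale_cost == "max_cost":
        factor = float(np.max(cost))
    else:
        raise ValueError(f"unsupported scale_cost: {scale_cost!r}")
    return cost / factor


def sq_dists(a, b):
    """Pairwise squared Euclidean distances between the rows of a and b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)


rng = np.random.default_rng(0)
xs, ys = rng.normal(size=(6, 2)), 100.0 * rng.normal(size=(7, 2))  # very different spatial units
xf, yf = rng.normal(size=(6, 5)), rng.normal(size=(7, 5))
costs = {"x": sq_dists(xs, xs), "y": sq_dists(ys, ys), "joint": sq_dists(xf, yf)}

# The same argument is applied to every term; each term uses its own statistic.
scaled = {name: apply_scale_cost(c, "max_cost") for name, c in costs.items()}
print({name: float(c.max()) for name, c in scaled.items()})
```

After scaling, the quadratic terms become comparable even though the two batches use coordinates that differ by a factor of 100.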