A differentiable JAX-based framework for vertex modeling and inverse design of epithelial tissues.
Epithelial tissues dynamically reshape through local mechanical interactions among cells. Understanding, inferring, and designing these mechanics is a central challenge in developmental biology and biophysics. VertAX is a computational framework built to address this challenge.
VertAX is a framework for vertex-based modeling: it represents epithelial tissues as two-dimensional polygonal meshes in which cells are faces, junctions are edges, tricellular contacts are vertices, and mechanical equilibrium is defined by the minimum of a user-specified energy. Built on JAX, VertAX is designed not only for forward simulation, but also for inverse problems such as parameter inference and tissue design.
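The sketch below makes this vocabulary concrete by building a two-cell cluster by hand. It uses a vertex array plus, for each cell, the counter-clockwise list of its vertex indices. This is a simplified stand-in chosen for illustration, not VertAX's internal vertex/half-edge/face tables. It shows how each junction appears as a pair of opposite half-edges and how neighboring cells can be read off from those pairs.

```python
import numpy as np

# Vertex coordinates (cell corners) of two regular hexagons with unit side.
vertices = np.array([
    [0.0, 0.0], [1.0, 0.0], [1.5, 0.866], [1.0, 1.732], [0.0, 1.732], [-0.5, 0.866],
    [2.5, 0.866], [3.0, 1.732], [2.5, 2.598], [1.5, 2.598],
])

# Cells (faces) as counter-clockwise lists of vertex indices.
# Cell 1 reuses vertices 2 and 3, so both cells share the junction between them.
faces = [
    [0, 1, 2, 3, 4, 5],
    [2, 6, 7, 8, 9, 3],
]

# Half-edges: directed junctions (tail, head), each owned by exactly one cell.
half_edges = [
    (face[i], face[(i + 1) % len(face)], cell)
    for cell, face in enumerate(faces)
    for i in range(len(face))
]
owner = {(tail, head): cell for tail, head, cell in half_edges}

# Undirected junctions (edges): a half-edge and its twin map to the same edge.
edges = {frozenset((tail, head)) for tail, head, _ in half_edges}

# Neighboring cells: a half-edge whose twin (head, tail) exists lies on an internal junction.
neighbors = {
    (cell, owner[(head, tail)]) for tail, head, cell in half_edges if (head, tail) in owner
}

n_v, n_e, n_f = len(vertices), len(edges), len(faces)
print(n_v, n_e, n_f)       # 10 11 2
print(n_v - n_e + n_f)     # 1: Euler characteristic of a planar cluster (outer face not counted)
print(sorted(neighbors))   # [(0, 1), (1, 0)]
```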
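VertAX treats inverse modeling as a bilevel optimization problem:

$$
\min_{\theta} \; C\big(x^*(\theta)\big)
\quad \text{subject to} \quad
x^*(\theta) = \arg\min_{x} E(x, \theta).
$$

Here, $x$ denotes the vertex positions of the mesh, $\theta$ the model parameters (for example, per-cell target shape factors), $E$ the mechanical energy whose minimum defines equilibrium, and $C$ the outer cost that compares the equilibrium tissue with a target geometry or design objective.

In other words, VertAX repeatedly solves a mechanical equilibrium problem for a given parameter set $\theta$ and then updates $\theta$ to decrease the outer cost $C$.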

Figure: Bilevel optimization loop in VertAX.
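A toy version of this loop fits in a few lines of JAX and is independent of the VertAX API. In the sketch below, which is illustrative only, the "tissue" is a chain of five coordinates `x`. The hypothetical energy `energy` pulls each coordinate toward its parameter θᵢ, with weak springs between neighbors. The inner problem is solved with a fixed number of gradient-descent steps, and the outer gradient comes from automatic differentiation through those unrolled steps.

```python
import jax
import jax.numpy as jnp

# Toy inner energy E(x, theta) (hypothetical stand-in for a vertex-model energy):
# each coordinate is pulled toward theta_i, with weak springs between neighbors.
def energy(x, theta):
    return 0.5 * jnp.sum((x - theta) ** 2) + 0.05 * jnp.sum((x[1:] - x[:-1]) ** 2)

def inner_solve(theta, x0, lr=0.2, n_steps=200):
    """Approximate x*(theta) = argmin_x E(x, theta) with unrolled gradient descent."""
    grad_x = jax.grad(energy, argnums=0)

    def step(x, _):
        return x - lr * grad_x(x, theta), None

    x_star, _ = jax.lax.scan(step, x0, None, length=n_steps)
    return x_star

def outer_cost(theta, x0, x_target):
    """C(x*(theta)): squared distance between the equilibrium and a target geometry."""
    x_star = inner_solve(theta, x0)
    return 0.5 * jnp.sum((x_star - x_target) ** 2)

x0 = jnp.zeros(5)
x_target = jnp.linspace(0.0, 1.0, 5)
theta = jnp.zeros(5)

# Outer gradient by automatic differentiation through every inner step.
outer_grad = jax.jit(jax.grad(outer_cost))
for _ in range(200):
    theta = theta - 0.5 * outer_grad(theta, x0, x_target)

print(outer_cost(theta, x0, x_target))  # close to 0: the equilibrium now matches the target
```

Here the target is exactly reachable, so the outer cost drops to round-off level. In a real inverse problem, it plateaus at whatever mismatch the model cannot explain.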
- 🧩 **Bilevel optimization framework**: VertAX formulates inverse problems as nested optimization, with an inner mechanical equilibrium problem and an outer parameter-learning problem.
- 🔬 **Multiple gradient strategies**: supports Automatic Differentiation (AD), Implicit Differentiation (ID), and Equilibrium Propagation (EP).
- 🔁 **Differentiable and non-differentiable workflows**: VertAX supports fully differentiable pipelines, while EP also enables inverse modeling with simulators that are only accessible through repeated executions.
- ⚡ **GPU acceleration with JAX**: JIT compilation and vectorization enable efficient simulations on CPU and GPU.
- 🎨 **Custom energies and costs in plain Python**: define your own mechanical models and inverse-design objectives without changing the library internals.
- 🏗️ **Two simulation modes**: supports both periodic tissues (bulk mechanics) and bounded tissues (finite clusters with curved interfaces).
- 🔀 **Automatic topology changes**: handles T1 neighbor exchanges during optimization.
- 🔗 **Seamless ML integration**: designed to work naturally with the JAX/Optax ecosystem.
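As an illustration of a custom energy written in plain Python, the sketch below implements a simplified area-perimeter vertex-model energy, E = Σᵢ K(Aᵢ − A₀)² + Γ Pᵢ². It uses `jax.vmap` and `jax.jit` to evaluate all cells at once. The function names, parameter values, and input layout are assumptions made to keep the example self-contained. In particular, each cell carries its own copy of its vertices here, whereas VertAX's energies take the shared mesh tables as inputs.

```python
import jax
import jax.numpy as jnp

def polygon_area(poly):
    """Shoelace area of a polygon with counter-clockwise vertices, shape (n, 2)."""
    x, y = poly[:, 0], poly[:, 1]
    return 0.5 * jnp.sum(x * jnp.roll(y, -1) - jnp.roll(x, -1) * y)

def polygon_perimeter(poly):
    """Sum of junction lengths around the polygon."""
    return jnp.sum(jnp.linalg.norm(jnp.roll(poly, -1, axis=0) - poly, axis=1))

def cell_energy(poly, k_area=1.0, a0=1.0, gamma=0.1):
    """Area elasticity plus perimeter contractility (illustrative parameter values)."""
    return k_area * (polygon_area(poly) - a0) ** 2 + gamma * polygon_perimeter(poly) ** 2

@jax.jit
def tissue_energy(cells):
    # cells: (n_cells, n_vertices_per_cell, 2); vmap evaluates all cells in one vectorized call.
    return jnp.sum(jax.vmap(cell_energy)(cells))

# Mechanical forces on the vertices are minus the energy gradient.
vertex_forces = jax.jit(lambda cells: -jax.grad(tissue_energy)(cells))

unit_square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
cells = jnp.stack([unit_square, unit_square + jnp.array([1.0, 0.0])])

print(polygon_area(unit_square), polygon_perimeter(unit_square))  # area 1, perimeter 4
print(tissue_energy(cells))        # 2 cells x (0 + 0.1 * 4**2)
print(vertex_forces(cells).shape)  # (2, 4, 2): one 2D force per vertex copy
```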
We recommend installing VertAX in a virtual environment:

```bash
python -m venv .venv
source .venv/bin/activate
```

Then install from source:

```bash
git clone https://github.com/VirtualEmbryo/VertAX.git
cd VertAX
pip install -e .
```

or from PyPI:

```bash
pip install vertax
```

Dependencies: JAX, Optax, SciPy (for Voronoi initialization), Matplotlib (for plotting).

For GPU support, install JAX with CUDA as described in the JAX docs before installing VertAX.
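After installing a CUDA-enabled JAX, you can quickly confirm that JAX actually sees the GPU by querying its backend and devices. These are standard JAX calls, independent of VertAX. The jitted dot product is only a smoke test.

```python
import jax
import jax.numpy as jnp

# Backend JAX will use by default, typically "cpu", "gpu" or "tpu".
print(jax.default_backend())
# All devices visible to JAX (one entry per GPU when CUDA is set up correctly).
print(jax.devices())

# Smoke test: a jitted dot product runs on the default device.
total = jax.jit(lambda v: v @ v)(jnp.ones(1000))
print(total)  # 1000.0
```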
VertAX supports two complementary simulation modes, designed for different classes of epithelial mechanics problems. The periodic mode is best suited for bulk tissue dynamics without explicit external boundaries, while the bounded mode is designed for finite tissue clusters with curved free interfaces. Both modes share the same vertex-based formulation and optimization framework, but differ in how boundaries are represented and initialized.
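One practical difference between the two modes is how the vector along a junction is measured. In a periodic box, a junction can cross the box boundary, so its vector is usually computed with the minimum-image convention. In a bounded cluster, the plain difference of vertex positions suffices. The sketch below illustrates this generic technique. The function names are hypothetical, and VertAX may store periodic information differently (for example, as per-half-edge offsets).

```python
import jax.numpy as jnp

def periodic_edge_vector(p_from, p_to, width, height):
    """Shortest displacement from p_from to p_to in a periodic width x height box.

    Minimum-image convention: assumes every junction is shorter than half the box
    size, which holds when cells are much smaller than the periodic domain.
    """
    box = jnp.array([width, height])
    d = p_to - p_from
    return d - box * jnp.round(d / box)

def bounded_edge_vector(p_from, p_to):
    """In a bounded cluster there is no wrapping: the plain difference."""
    return p_to - p_from

width = height = 4.0
a = jnp.array([0.1, 2.0])
b = jnp.array([3.9, 2.0])
print(periodic_edge_vector(a, b, width, height))  # about [-0.2, 0.0]: the junction wraps across x = 0
print(bounded_edge_vector(a, b))                  # about [3.8, 0.0]
```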
VertAX implements and benchmarks three complementary methods for computing outer gradients through the implicit inner problem:
| Method | How it works | Pros | Cons |
|---|---|---|---|
| AD (Automatic Differentiation) | Unrolls the inner optimization steps; forward-mode JVP via `jax.jacfwd` | Exact for differentiable pipelines; easy in JAX | Cost scales with (number of iterations) × (number of parameters) |
| ID (Implicit Differentiation) | Differentiates the optimality condition ∇ₓE = 0 via the Implicit Function Theorem; JVP or adjoint (VJP) variant | No unrolling; constant memory; exact near equilibrium | Requires a Hessian solve; sensitive to ill-conditioning |
| EP (Equilibrium Propagation) | Estimates the gradient from perturbed free and nudged equilibria; no backpropagation required | Memory-efficient; works with non-differentiable or incomplete solvers | Approximate; depends on the perturbation size β |
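The ID row can be made concrete on a small toy problem. The sketch below uses a hypothetical quadratic-plus-springs energy rather than a VertAX energy. It relaxes to equilibrium and then applies the adjoint (VJP) variant: solve H λ = ∇ₓC with the Hessian H of E, then form dC/dθ = −(∂θ∇ₓE)ᵀ λ. The result is checked against AD through the unrolled relaxation. The dense `jnp.linalg.solve` is only reasonable for tiny systems. A large mesh would call for a matrix-free solver, such as conjugate gradients with Hessian-vector products.

```python
import jax
import jax.numpy as jnp

def energy(x, theta):
    # Hypothetical toy energy standing in for a vertex-model energy.
    return 0.5 * jnp.sum((x - theta) ** 2) + 0.05 * jnp.sum((x[1:] - x[:-1]) ** 2)

def cost(x, x_target):
    return 0.5 * jnp.sum((x - x_target) ** 2)

def solve_equilibrium(theta, x0, lr=0.2, n_steps=200):
    """Gradient-descent relaxation to x*(theta), where grad_x E(x*, theta) = 0."""
    grad_x = jax.grad(energy, argnums=0)
    return jax.lax.fori_loop(0, n_steps, lambda _, x: x - lr * grad_x(x, theta), x0)

def implicit_gradient(theta, x_star, x_target):
    """Adjoint implicit differentiation of C(x*(theta)) at an equilibrium.

    Differentiating grad_x E(x*(theta), theta) = 0 gives H dx*/dtheta = -d(grad_x E)/dtheta,
    so dC/dtheta = -(d(grad_x E)/dtheta)^T lam, with H lam = grad_x C (H is symmetric).
    """
    hessian = jax.hessian(energy, argnums=0)(x_star, theta)  # dense: fine for a toy problem
    lam = jnp.linalg.solve(hessian, jax.grad(cost)(x_star, x_target))
    _, vjp_theta = jax.vjp(lambda th: jax.grad(energy, argnums=0)(x_star, th), theta)
    return -vjp_theta(lam)[0]

x0 = jnp.zeros(5)
x_target = jnp.linspace(0.0, 1.0, 5)
theta = jnp.array([0.3, -0.1, 0.5, 0.2, 0.9])

x_star = solve_equilibrium(theta, x0)
g_id = implicit_gradient(theta, x_star, x_target)

# Reference: AD through all unrolled relaxation steps (the AD row of the table).
g_ad = jax.grad(lambda th: cost(solve_equilibrium(th, x0), x_target))(theta)
print(jnp.max(jnp.abs(g_id - g_ad)))  # small: both agree at equilibrium
```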
In practice, AD and EP often recover similar parameter trends on synthetic inverse problems, while EP is especially attractive for simulators that cannot be made fully differentiable.
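Equilibrium Propagation can be sketched on the same kind of toy problem. Let x₀ be the free equilibrium of E and x_β the nudged equilibrium of E + βC. The estimate is then (∂θE(x_β, θ) − ∂θE(x₀, θ)) / β. The relaxation routine is called as a black box and is never differentiated through. Only ∂E/∂θ at the two equilibria is needed. The energy, cost, and value of β are illustrative assumptions. Unrolled AD is computed at the end purely as a reference.

```python
import jax
import jax.numpy as jnp

def energy(x, theta):
    # Hypothetical toy energy standing in for a vertex-model energy.
    return 0.5 * jnp.sum((x - theta) ** 2) + 0.05 * jnp.sum((x[1:] - x[:-1]) ** 2)

def cost(x, x_target):
    return 0.5 * jnp.sum((x - x_target) ** 2)

def relax(total_energy, x0, lr=0.2, n_steps=200):
    """Relax x to a minimum of total_energy; EP treats this routine as a black box."""
    grad_x = jax.grad(total_energy)
    return jax.lax.fori_loop(0, n_steps, lambda _, x: x - lr * grad_x(x), x0)

def ep_gradient(theta, x_target, x0, beta=1e-3):
    """Equilibrium Propagation estimate of dC(x*(theta))/dtheta from two equilibria."""
    # Free equilibrium: minimum of E alone.
    x_free = relax(lambda x: energy(x, theta), x0)
    # Nudged equilibrium: minimum of E + beta * C, started from the free state.
    x_nudged = relax(lambda x: energy(x, theta) + beta * cost(x, x_target), x_free)
    # Only the partial derivative of E with respect to theta is evaluated.
    dE_dtheta = jax.grad(energy, argnums=1)
    return (dE_dtheta(x_nudged, theta) - dE_dtheta(x_free, theta)) / beta

x0 = jnp.zeros(5)
x_target = jnp.linspace(0.0, 1.0, 5)
theta = jnp.array([0.3, -0.1, 0.5, 0.2, 0.9])

g_ep = ep_gradient(theta, x_target, x0)
# Reference gradient by unrolled AD, for comparison only.
g_ad = jax.grad(lambda th: cost(relax(lambda x: energy(x, th), x0), x_target))(theta)
print(jnp.max(jnp.abs(g_ep - g_ad)))  # O(beta) bias plus float32 round-off
```

Shrinking β reduces the O(β) bias but amplifies round-off in the difference of the two equilibria. This is the dependence on the perturbation size β listed in the table.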
See the `examples/` folder for in-depth examples:

| Notebook | Description |
|---|---|
| `inverse_modelling_example.ipynb` | Inverse modeling with periodic boundary conditions |
| `inverse_modelling_example_bounded.ipynb` | Inverse design with a bounded cluster (convergent extension) |
See the `docs/` folder for the full documentation.
VertAX can also optimize model parameters to match a target geometry.
```python
import math
import jax
import jax.numpy as jnp
import optax
from vertax import PbcBilevelOptimizer, PbcMesh, BilevelOptimizationMethod, plot_mesh
from vertax.cost import cost_v2v
from vertax.energy import energy_shape_factor_hetero

# --- Mesh setup ---
n_cells = 20
width = height = math.sqrt(n_cells)

# New mesh with periodic boundary conditions and 20 cells.
mesh = PbcMesh.from_random_seeds(
    nb_seeds=n_cells, width=width, height=height, random_key=0
)

# --- Attach parameters ---
mesh.vertices_params = jnp.zeros(mesh.nb_vertices)
mesh.edges_params = jnp.zeros(mesh.nb_half_edges)  # not used here
mesh.faces_params = jnp.full(mesh.nb_faces, 3.7)   # initial target shape factors
selected_faces = jnp.arange(mesh.nb_faces)

# --- Built-in energy function ---
def energy(vertTable, heTable, faceTable, _vert_params, _he_params, face_params):
    return energy_shape_factor_hetero(
        vertTable, heTable, faceTable,
        width, height,
        selected_faces,
        face_params,
    )

# --- Optimizer setup ---
optimizer = PbcBilevelOptimizer()
optimizer.loss_function_inner = energy
optimizer.inner_solver = optax.sgd(learning_rate=0.01)
optimizer.update_T1 = True
optimizer.min_dist_T1 = 0.005

# --- Relax the initial mesh ---
optimizer.inner_optimization(mesh)

# --- Create a target mesh with different face parameters ---
target = PbcMesh.copy_mesh(mesh)
key = jax.random.PRNGKey(1)
target.faces_params = 3.7 + 0.2 * jax.random.normal(key, shape=(target.nb_faces,))
target.vertices_params = jnp.zeros(target.nb_vertices)
target.edges_params = jnp.zeros(target.nb_half_edges)
optimizer.inner_optimization(target)

# --- Register the target ---
optimizer.vertices_target = target.vertices.copy()
optimizer.edges_target = target.edges.copy()
optimizer.faces_target = target.faces.copy()

# --- Outer loss and bilevel method ---
optimizer.loss_function_outer = cost_v2v
optimizer.outer_solver = optax.adam(learning_rate=1e-4, nesterov=True)
optimizer.bilevel_optimization_method = BilevelOptimizationMethod.EQUILIBRIUM_PROPAGATION

# --- Run bilevel optimization ---
for epoch in range(20):
    optimizer.bilevel_optimization(mesh)

plot_mesh(mesh, title="Recovered mesh after inverse modeling")
```

For full inverse-modeling examples, see the Tutorials section.
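A custom outer cost is just a JAX function of the relaxed geometry. The sketch below writes a vertex-to-vertex cost in the spirit of `cost_v2v`. The name `vertex_to_vertex_cost`, its signature (two matched vertex arrays and an optional 0/1 mask), and the mean-squared form are assumptions. The built-in `cost_v2v` may use a different signature and normalization. The sketch also ignores periodic wrapping.

```python
import jax
import jax.numpy as jnp

def vertex_to_vertex_cost(vertices, target_vertices, mask=None):
    """Mean squared distance between matched vertices (hypothetical helper).

    vertices, target_vertices: (n_vertices, 2) arrays with the same vertex ordering.
    mask: optional (n_vertices,) array of 0/1 weights selecting the vertices to fit.
    """
    sq_dist = jnp.sum((vertices - target_vertices) ** 2, axis=-1)
    if mask is None:
        return jnp.mean(sq_dist)
    return jnp.sum(mask * sq_dist) / jnp.maximum(jnp.sum(mask), 1.0)

vertices = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
target = jnp.array([[0.0, 0.0], [1.0, 0.5], [1.0, 1.0]])

print(vertex_to_vertex_cost(vertices, target))                               # 0.25 / 3
print(vertex_to_vertex_cost(vertices, target, jnp.array([0.0, 1.0, 0.0])))  # 0.25
# Gradient with respect to the vertices: nonzero only for the mismatched vertex 1.
print(jax.grad(vertex_to_vertex_cost)(vertices, target))
```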
If you use VertAX in your research, please cite:
```bibtex
@misc{pasqui2026vertaxdifferentiablevertexmodel,
  title={VertAX: a differentiable vertex model for learning epithelial tissue mechanics},
  author={Alessandro Pasqui and Jim Martin Catacora Ocana and Anshuman Sinha and Matthieu Perez and Fabrice Delbary and Giorgio Gosti and Mattia Miotto and Domenico Caudo and Maxence Ernoult and Hervé Turlier},
  year={2026},
  eprint={2604.06896},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2604.06896},
}
```
This project received funding from the European Union’s Horizon 2020 research and innovation programme under the European Research Council (ERC) grant agreement no. 949267, and under the Marie Skłodowska-Curie grant agreement no. 945304 (Cofund AI4theSciences, hosted by PSL University). AP, JMCO, AS, FB, MP, FD and HT acknowledge support from CNRS and Collège de France.
VertAX is distributed under the Creative Commons Attribution–ShareAlike 4.0 International (CC BY-SA 4.0) license.
You are free to share and adapt the material, provided that appropriate credit is given and that any derivative work is distributed under the same license.

