diff --git a/chapter3/demo_smoothed_tv_inpainting.ipynb b/chapter3/demo_smoothed_tv_inpainting.ipynb new file mode 100644 index 00000000..02d9800e --- /dev/null +++ b/chapter3/demo_smoothed_tv_inpainting.ipynb @@ -0,0 +1,838 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1d540f69", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "# Smoothed TV image inpainting\n", + "\n", + "```{admonition} Download sources\n", + ":class: download\n", + "* {download}`Python script <./demo_smoothed_tv_inpainting.py>`\n", + "* {download}`Jupyter notebook <./demo_smoothed_tv_inpainting.ipynb>`\n", + "```\n", + "\n", + "This demo solves a variational image inpainting problem on the\n", + "unit square. A synthetic image is masked on an irregular interior region,\n", + "and the missing values are reconstructed using smoothed total\n", + "variation (TV) regularization.\n", + "\n", + "## Problem Definition\n", + "\n", + "Let $\\Omega = [0,1]^2$ be the image domain. We define:\n", + "\n", + "- $u_{\\mathrm{true}}$: synthetic ground-truth image\n", + "- $m$: mask, equal to 1 on known data and 0 on the missing region\n", + "- $f = m u_{\\mathrm{true}}$: observed incomplete image\n", + "- $u$: reconstructed image\n", + "\n", + "We compute $u$ by minimising\n", + "\n", + "$$\n", + "J(u)= {1 \\over 2}\\beta \\int_\\Omega m(u-f)^2\\,\\mathrm{d}x\n", + "+ \\alpha \\int_\\Omega \\sqrt{||\\nabla u||^2 + \\varepsilon^2}~\\mathrm{d}x.\n", + "$$\n", + "\n", + "The first term enforces agreement with the known image data, while\n", + "the second term is a smoothed total variation regularisation term.\n", + "It promotes piecewise smooth solutions and preserves edges.\n", + "$\\alpha$ and $\\beta$ control the balance between data fidelity\n", + "(fit to $f$) and smoothness.\n", + "The parameter $\\varepsilon>0$ smooths the TV functional so that\n", + "it is differentiable and the problem can be solved with Newton-type methods.\n", + "\n", + "## Discretization\n", + "We discretize the problem 
using\n", + "- a first order Lagrange finite element space\n", + "- a triangular mesh of the unit square\n", + "\n", + "## Implementation\n", + "\n", + "We use a first-order Lagrange space on a triangular mesh\n", + "of the unit square.\n", + "The nonlinear problem is solved with\n", + "{py:class}`PETSc SNES` through\n", + "{py:class}`NonlinearProblem `." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cfdd96e", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from mpi4py import MPI\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.tri as mtri\n", + "import numpy as np\n", + "\n", + "import ufl\n", + "from dolfinx import fem, mesh\n", + "from dolfinx.fem.petsc import NonlinearProblem\n" + ] + }, + { + "cell_type": "markdown", + "id": "f7fab203", + "metadata": {}, + "source": [ + "We discretize the domain $\\Omega =[0,1]^2$ using a triangular\n", + "mesh, where `nx` and `ny` control the resolution of the mesh." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fa206d9", + "metadata": {}, + "outputs": [], + "source": [ + "nx = 128\n", + "ny = 128\n", + "msh = mesh.create_unit_square(MPI.COMM_WORLD, nx, ny)" + ] + }, + { + "cell_type": "markdown", + "id": "f6c9a0cf", + "metadata": {}, + "source": [ + "We use first order Lagrange elements for discretizing the image.\n", + "In this space, the DOFs are the values of $u$ at the mesh vertices;\n", + "the solution is continuous but has a piecewise constant gradient." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed2da823", + "metadata": {}, + "outputs": [], + "source": [ + "V = fem.functionspace(msh, (\"Lagrange\", 1))" + ] + }, + { + "cell_type": "markdown", + "id": "00cacb25", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "### Ground Truth image $u_{true}$\n", + "We define a synthetic binary image\n", + "\n", + "$$\n", + " u_{true}=\\begin{cases}\n", + " 1 & \\text{ if } (x,y) \\text{ is inside a square}\\\\\n", + " 0 &\\text{ otherwise}\n", + " \\end{cases}\n", + "$$\n", + "\n", + "The square is defined as $0.2 0.2) & (X < 0.8) & (Y > 0.2) & (Y < 0.8)).astype(np.float64)" + ] + }, + { + "cell_type": "markdown", + "id": "4d71b42e", + "metadata": {}, + "source": [ + "### Mask $m(x,y)$\n", + "The mask defines which pixels are known and which are missing\n", + "\n", + "$$\n", + " m(x,y)=\\begin{cases}\n", + " 1& \\text{ known data}\\\\\n", + " 0 & \\text{ missing region}\n", + " \\end{cases}\n", + "$$\n", + "\n", + "We construct a mask with random \"holes\" inside the square:\n", + "* small circular regions are removed and set to 0\n", + "* everywhere else remains known (1)" + ] + }, + { + "cell_type": "markdown", + "id": "ea7f3dd6", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "This creates a challenging inpainting problem, with:\n", + "* many small missing regions\n", + "* irregular geometry\n", + "\n", + "The solver must reconstruct these missing values\n", + 
"using smoothness (TV regularization)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94d4f171", + "metadata": {}, + "outputs": [], + "source": [ + "def mask_function(x):\n", + " \"\"\"Create a mask with random circular holes inside the square.\"\"\"\n", + " X = x[0]\n", + " Y = x[1]\n", + " # all pixels known\n", + " mask = np.ones_like(X, dtype=np.float64)\n", + " # number of speckles\n", + " num_speckles = 25\n", + " # random centers\n", + " generator = np.random.Generator(np.random.MT19937(0)) # random seed for reproducibility\n", + "\n", + " cx = generator.uniform(0.25, 0.75, num_speckles)\n", + " cy = generator.uniform(0.25, 0.75, num_speckles)\n", + " # random radii (small + varied)\n", + " radii = generator.uniform(0.012, 0.035, num_speckles)\n", + " # create holes. mask =0 inside circles\n", + " for i in range(num_speckles):\n", + " r2 = (X - cx[i]) ** 2 + (Y - cy[i]) ** 2\n", + " mask[r2 < radii[i] ** 2] = 0.0\n", + " return mask" + ] + }, + { + "cell_type": "markdown", + "id": "36f05fd5", + "metadata": {}, + "source": [ + "We interpolate the exact image and the mask into the finite element\n", + "space, and construct the observed damaged image,\n", + "where $u_{true}$ is our true image, $m: \\Omega \\to \\mathbb{R}$\n", + "is the mask, $f: \\Omega \\to \\mathbb{R}$ is the observed damaged image,\n", + "and $u:\\Omega \\to \\mathbb{R}$ is the reconstructed image." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "727bb0e2", + "metadata": {}, + "outputs": [], + "source": [ + "u_true = fem.Function(V, name=\"true_image\")\n", + "u_true.interpolate(true_image)\n", + "m = fem.Function(V, name=\"mask\")\n", + "m.interpolate(mask_function)\n", + "f = fem.Function(V, name=\"observed_image\")\n", + "f.x.array[:] = m.x.array * u_true.x.array\n", + "u = fem.Function(V, name=\"reconstructed_image\")\n", + "u.x.array[:] = f.x.array.copy()" + ] + }, + { + "cell_type": "markdown", + "id": "f8b1b638", + "metadata": {}, + "source": [ + "We now define the nonlinear variational problem corresponding to the\n", + "smoothed total variation regularised inpainting model.\n", + "\n", + "The Euler-Lagrange equation for $J(u)$ leads to the weak form:\n", + "find $u\\in V$ such that\n", + "\n", + "$$\n", + "\\beta \\int_\\Omega m(u-f)v\\,\\mathrm{d}x\n", + "+ \\alpha \\int_\\Omega\n", + "{\\nabla u\\cdot\\nabla v \\over \\sqrt{||\\nabla u||^2+\\varepsilon^2}}\n", + "\\,\\mathrm{d}x = 0\n", + "$$\n", + "\n", + "for all test functions $v$. 
This is a nonlinear problem due to\n", + "the TV term.\n", + "The total variation integrand is $\\vert\\vert\\nabla u\\vert\\vert$,\n", + "but in practice one uses a smoothed version to allow for differentiation\n", + "and Newton-type solvers:\n", + "\n", + "$$\n", + " TV = \\sqrt{\\vert\\vert\\nabla u\\vert\\vert^2 +\\varepsilon^2}\n", + "$$\n", + "\n", + "where $\\varepsilon$ controls the smoothing of the TV term:\n", + "* large $\\varepsilon$: smoother, more like quadratic diffusion\n", + "* small $\\varepsilon$: closer to true TV, edge preserving" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13661a58", + "metadata": {}, + "outputs": [], + "source": [ + "alpha = fem.Constant(msh, 0.003)\n", + "beta = fem.Constant(msh, 1.0)\n", + "eps = fem.Constant(msh, 1.0e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "f5d8efaa", + "metadata": {}, + "source": [ + "Smoothed TV inpainting energy functional.\n", + "We define the energy $J(u)$ and use `ufl.derivative` to obtain\n", + "the residual form $F$.\n", + "\n", + "$$\n", + "J(u) = {1 \\over 2}\\beta\\int_\\Omega m(u-f)^2\\,dx\n", + "+ \\alpha\\int_\\Omega \\sqrt{||\\nabla u||^2+\\varepsilon^2}\\,dx\n", + "$$\n", + "\n", + "Taking the first variation gives the weak form $F(u; v)$.\n", + "\n", + "$$\n", + "F(u; v) =\n", + "\\beta\\int_\\Omega m(u-f)v\\,dx\n", + "+ \\alpha\\int_\\Omega\n", + "{\\nabla u\\cdot\\nabla v\n", + "\\over\n", + "\\sqrt{||\\nabla u||^2+\\varepsilon^2}}\\,dx\n", + "= 0 \\quad \\forall v\\in V.\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f719988", + "metadata": {}, + "outputs": [], + "source": [ + "v = ufl.TestFunction(V)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0055665b", + "metadata": {}, + "outputs": [], + "source": [ + "J_energy = (\n", + " 0.5 * beta * m * (u - f) ** 2 * ufl.dx\n", + " + alpha * ufl.sqrt(ufl.inner(ufl.grad(u), ufl.grad(u)) + eps**2) * ufl.dx\n", + ")" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "3fbade6e", + "metadata": {}, + "outputs": [], + "source": [ + "F = ufl.derivative(J_energy, u, v)" + ] + }, + { + "cell_type": "markdown", + "id": "77b7387d", + "metadata": {}, + "source": [ + "This formulation is based on total variation (TV) regulaization\n", + "for image denoising and inpainting\n", + "{cite:t}`tv-RUDIN1992TV,tv-CHAN2001TV`." + ] + }, + { + "cell_type": "markdown", + "id": "2d37679d", + "metadata": {}, + "source": [ + "A nonlinear PETSc problem is created and solved with a Newton line-search\n", + "method, with an LU factorization for the linearized system\n", + "$F'(u_k) s= -F(u_k)$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0887a5f4", + "metadata": {}, + "outputs": [], + "source": [ + "petsc_options = {\n", + " \"snes_type\": \"newtonls\",\n", + " \"snes_linesearch_type\": \"bt\",\n", + " \"snes_rtol\": 1.0e-8,\n", + " \"snes_atol\": 1.0e-8,\n", + " \"snes_max_it\": 1000,\n", + " \"ksp_type\": \"preonly\",\n", + " \"pc_type\": \"lu\",\n", + "}\n", + "\n", + "problem = NonlinearProblem(\n", + " F,\n", + " u,\n", + " bcs=[],\n", + " petsc_options_prefix=\"tv_inpainting_\",\n", + " petsc_options=petsc_options,\n", + ")\n", + "\n", + "problem.solve()" + ] + }, + { + "cell_type": "markdown", + "id": "07a454e4", + "metadata": {}, + "source": [ + "## Model Validation and Results\n", + "These diagnostics asses\n", + "1. whether the nonlinear Newton/SNES solve converged\n", + "2. whether the variational objective decreased\n", + "3. 
how accurate the reconstruction is globally and in the hole region" + ] + }, + { + "cell_type": "markdown", + "id": "cd74b7a9", + "metadata": {}, + "source": [ + "FEM Metrics\n", + "The global number of degrees of freedom reports the size of the\n", + "finite element discretisation. The H1 seminorm error measures the\n", + "gradient error\n", + "\n", + "$$\n", + " \\vert\\vert\\nabla(u-u_{true})\\vert\\vert_{L_2 (\\Omega)}\n", + "$$\n", + "\n", + "This is useful as TV regularization is gradient based.\n", + "Smaller values mean the reconstruction recovers edge structure better." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f76353ca", + "metadata": {}, + "outputs": [], + "source": [ + "num_dofs = V.dofmap.index_map.size_global\n", + "h1_semi_error = fem.assemble_scalar(\n", + " fem.form(ufl.inner(ufl.grad(u - u_true), ufl.grad(u - u_true)) * ufl.dx)\n", + ")\n", + "h1_semi_error = np.sqrt(msh.comm.allreduce(h1_semi_error, op=MPI.SUM))" + ] + }, + { + "cell_type": "markdown", + "id": "b8b387f0", + "metadata": {}, + "source": [ + "Reconstruction Errors\n", + "Data fidelity (known region only):\n", + "\n", + "$$\n", + " \\left(\\int_\\Omega m(u-f)^2\\,\\mathrm{d}x\\right)^{1/2}\n", + "$$\n", + "\n", + "measures the agreement with the known image data.\n", + "Smaller values mean the reconstruction matches the observed pixels better." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4539f66e", + "metadata": {}, + "outputs": [], + "source": [ + "data_error = fem.assemble_scalar(fem.form(m * (u - f) ** 2 * ufl.dx))\n", + "data_error = np.sqrt(msh.comm.allreduce(data_error, op=MPI.SUM))" + ] + }, + { + "cell_type": "markdown", + "id": "fc2df069", + "metadata": {}, + "source": [ + "Smoothed TV seminorm\n", + "\n", + "$$\n", + " \\int_{\\Omega}\\sqrt{\\vert\\vert\\nabla u \\vert\\vert^2\n", + " +\\varepsilon^2}~\\mathrm{d}x\n", + "$$\n", + "\n", + "This is the regularization term in the objective (without the factor $\\alpha$).\n", + "Smaller values mean a smoother reconstruction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d5b95c5", + "metadata": {}, + "outputs": [], + "source": [ + "tv_energy = fem.assemble_scalar(\n", + " fem.form(ufl.sqrt(ufl.inner(ufl.grad(u), ufl.grad(u)) + eps**2) * ufl.dx)\n", + ")\n", + "tv_energy = msh.comm.allreduce(tv_energy, op=MPI.SUM)" + ] + }, + { + "cell_type": "markdown", + "id": "8db1391e", + "metadata": {}, + "source": [ + "True error\n", + "\n", + "$$\n", + " \\vert\\vert u-u_{true} \\vert\\vert_{L_2 (\\Omega)}\n", + "$$\n", + "\n", + "Measures overall reconstruction accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a1b0ffb", + "metadata": {}, + "outputs": [], + "source": [ + "true_error = fem.assemble_scalar(fem.form((u - u_true) ** 2 * ufl.dx))\n", + "true_error = np.sqrt(msh.comm.allreduce(true_error, op=MPI.SUM))" + ] + }, + { + "cell_type": "markdown", + "id": "c0ed2b19", + "metadata": {}, + "source": [ + "Hole error\n", + "\n", + "$$\n", + " \\left(\\int_\\Omega (1-m)(u-u_{true})^2\\,\\mathrm{d}x\\right)^{1/2}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec2f6e97", + "metadata": {}, + "outputs": [], + "source": [ + "hole_error = fem.assemble_scalar(fem.form((1 - m) * (u - u_true) ** 2 * ufl.dx))\n", + "hole_error = np.sqrt(msh.comm.allreduce(hole_error, op=MPI.SUM))" + ] + 
}, + { + "cell_type": "markdown", + "id": "59550d88", + "metadata": {}, + "source": [ + "Image quality metric\n", + "PSNR (peak signal-to-noise ratio) is a standard imaging metric.\n", + "Since the image range is [0,1], we use\n", + "\n", + "$$\n", + " PSNR=10\\log_{10}(1/MSE)\n", + "$$\n", + "\n", + "Larger PSNR means better reconstruction quality." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc15b0c2", + "metadata": {}, + "outputs": [], + "source": [ + "mse = np.mean((u.x.array - u_true.x.array) ** 2)\n", + "if mse == 0:\n", + " psnr = np.inf\n", + "else:\n", + " psnr = 10.0 * np.log10(1.0 / mse)" + ] + }, + { + "cell_type": "markdown", + "id": "89d6a263", + "metadata": {}, + "source": [ + "Newton Linesearch metrics\n", + "These measure whether the nonlinear solve succeeded. We want:\n", + "* a positive converged reason\n", + "* a small final residual norm\n", + "* a reasonable number of iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87eaa8f8", + "metadata": {}, + "outputs": [], + "source": [ + "snes = problem.solver\n", + "reason = snes.getConvergedReason()\n", + "iters = snes.getIterationNumber()\n", + "final_residual = snes.getFunctionNorm()" + ] + }, + { + "cell_type": "markdown", + "id": "f5c78195", + "metadata": {}, + "source": [ + "Objective values\n", + "We compare the initial objective $J(f)$\n", + "with the final objective $J(u)$, where\n", + "\n", + "$$\n", + " J(v)={1\\over 2}\\beta \\int m(v-f)^2 dx\n", + " +\\alpha \\int \\sqrt{||\\nabla v||^2+\\varepsilon^2} dx\n", + "$$\n", + "\n", + "A decrease in the objective shows that the nonlinear optimization\n", + "improved the damaged image under the smoothed TV model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ce4e5a0", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "objective_value = 0.5 * float(beta) * data_error**2 + float(alpha) * tv_energy\n", + "if reason > 0:\n", + " status = \"converged\"\n", + 
"else:\n", + " status = \"not converged\"\n", + "\n", + "u0 = fem.Function(V)\n", + "u0.x.array[:] = f.x.array.copy()\n", + "\n", + "J0_data = fem.assemble_scalar(fem.form(m * (u0 - f) ** 2 * ufl.dx))\n", + "J0_data = msh.comm.allreduce(J0_data, op=MPI.SUM)\n", + "\n", + "J0_tv = fem.assemble_scalar(\n", + " fem.form(ufl.sqrt(ufl.inner(ufl.grad(u0), ufl.grad(u0)) + eps**2) * ufl.dx)\n", + ")\n", + "J0_tv = msh.comm.allreduce(J0_tv, op=MPI.SUM)\n", + "\n", + "J0 = 0.5 * float(beta) * J0_data + float(alpha) * J0_tv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65e16ee1", + "metadata": {}, + "outputs": [], + "source": [ + "# Printing statments for validation and metrics\n", + "# If on main process\n", + "if msh.comm.rank == 0:\n", + " print(\"---Smoothed TV inpainting results---\")\n", + "\n", + " print(\"--FEM Metrics--\")\n", + " print(f\"Global DOFs: {num_dofs}\")\n", + " print(f\"H1 seminorm error: {h1_semi_error}\")\n", + "\n", + " print(\"--Newton Linesearch:--\")\n", + " print(\"-Optimization:-\")\n", + " print(f\"Initial objective J(f): {J0:.4e}\")\n", + " print(f\"Final objective J(u): {objective_value:.4e}\")\n", + " print(f\"Relative decrease: {(J0 - objective_value) / J0:.2%}\")\n", + "\n", + " print(\"-Solver convergence:-\")\n", + " print(f\"SNES iteration: {iters}\")\n", + " print(f\"SNES final residual norm: {final_residual:.4e}\")\n", + " print(f\"SNES status: {status}\")\n", + " print(f\"SNES converged reason: {reason}\")\n", + "\n", + " print(\"---Reconstruction Quality:---\")\n", + " print(f\"Data error (known region): {data_error:.4e}\")\n", + " print(f\"TV seminorm: {tv_energy:.4e}\")\n", + " print(f\"True L2 error: {true_error:.4e}\")\n", + " print(f\"Hole error: {hole_error:.4e}\")\n", + " print(f\"PSNR: {psnr:.2f} dB\")" + ] + }, + { + "cell_type": "markdown", + "id": "67026512", + "metadata": {}, + "source": [ + "## Visualization\n", + "We construct fields that allow us to visually asses the quality\n", + "of the 
reconstruction:\n", + "$u-u_{true}$ is the global reconstruction error and\n", + "$(1-m)(u-u_{true})$ is the hole error, restricted to the missing regions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bece848f", + "metadata": {}, + "outputs": [], + "source": [ + "u_minus_u_true = fem.Function(V)\n", + "u_minus_u_true.x.array[:] = u.x.array - u_true.x.array\n", + "\n", + "hole_error_field = fem.Function(V)\n", + "hole_error_field.x.array[:] = (1.0 - m.x.array) * (u.x.array - u_true.x.array)" + ] + }, + { + "cell_type": "markdown", + "id": "6e021441", + "metadata": {}, + "source": [ + "### FEM to matplotlib\n", + "The FEM solution $u$ is represented by values at degrees\n", + "of freedom (DOFs), not on a regular grid.\n", + "To plot it in matplotlib, we:\n", + "1. extract the coordinates of the DOFs\n", + "2. extract the function-space dofmap connectivity (triangles)\n", + "3. build a `Triangulation` object\n", + "This allows matplotlib to render the piecewise linear FEM solution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "313f897c", + "metadata": {}, + "outputs": [], + "source": [ + "coords = V.tabulate_dof_coordinates()\n", + "x, y = coords[:, 0], coords[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a5a2ec4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "triangles = V.dofmap.list\n", + "triang = mtri.Triangulation(x, y, triangles)" + ] + }, + { + "cell_type": "markdown", + "id": "7741b61f", + "metadata": {}, + "source": [ + "Plotting\n", + "We use `tripcolor` to plot scalar fields defined on a triangulated mesh.\n", + "`shading=\"flat\"` gives piecewise constant coloring per triangle,\n", + "which better reflects the discrete FEM representation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8729fbeb", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "def plot_field(ax, data, title, fig, cmap=\"viridis\", vmin=0.0, 
vmax=1.0):\n", + " \"\"\"Plot a scalar field on a triangulated mesh.\"\"\"\n", + " im = ax.tripcolor(triang, data, shading=\"flat\", cmap=cmap, vmin=vmin, vmax=vmax)\n", + " ax.set_title(title)\n", + " ax.set_aspect(\"equal\")\n", + " fig.colorbar(im, ax=ax)\n", + "\n", + "\n", + "fig, axes = plt.subplots(2, 3, figsize=(12, 8))\n", + "# $u_{true } $ground truth image\n", + "plot_field(axes[0, 0], u_true.x.array, \"u_true\", fig)\n", + "# m, mask with known (1) and missing (0) regions\n", + "plot_field(axes[0, 1], m.x.array, \"mask\", fig, cmap=\"gray\")\n", + "# f is the damaged image\n", + "plot_field(axes[0, 2], f.x.array, \"f\", fig)\n", + "# u is the reconstructed image\n", + "plot_field(axes[1, 0], u.x.array, \"u\", fig)\n", + "# Global error\n", + "lim = np.max(np.abs(u_minus_u_true.x.array))\n", + "# $u-u_{true}$ is the global reconstruction error\n", + "plot_field(\n", + " axes[1, 1], u_minus_u_true.x.array, \"u - u_true\", fig, cmap=\"coolwarm\", vmin=-lim, vmax=lim\n", + ")\n", + "# Hole only errors\n", + "lim = np.max(np.abs(hole_error_field.x.array))\n", + "# Hole only error restricted to the missing regions\n", + "plot_field(\n", + " axes[1, 2],\n", + " hole_error_field.x.array,\n", + " \"hole-only error\",\n", + " fig,\n", + " cmap=\"coolwarm\",\n", + " vmin=-lim,\n", + " vmax=lim,\n", + ")\n", + "plt.tight_layout()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "id": "1b56b8c1", + "metadata": {}, + "source": [ + "## References\n", + "```{bibliography}\n", + " :filter: cited\n", + " :labelprefix:\n", + " :keyprefix: tv-\n", + "```" + ] + } + ], + "metadata": { + "jupytext": { + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/chapter3/demo_smoothed_tv_inpainting.py b/chapter3/demo_smoothed_tv_inpainting.py new file mode 100644 index 00000000..e8654f8d --- /dev/null +++ b/chapter3/demo_smoothed_tv_inpainting.py @@ -0,0 +1,506 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# 
extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.13.6 +# --- + +# # Smoothed TV image inpainting +# +# ```{admonition} Download sources +# :class: download +# * {download}`Python script <./demo_smoothed_tv_inpainting.py>` +# * {download}`Jupyter notebook <./demo_smoothed_tv_inpainting.ipynb>` +# ``` +# +# This demo solves a variational image inpainting problem on the +# unit square. A synthetic image is masked on an irregular interior region, +# and the missing values are reconstructed using smoothed total +# variation (TV) regularization. +# +# ## Problem Definition +# +# Let $\Omega = [0,1]^2$ be the image domain. We define: +# +# - $u_{\mathrm{true}}$: synthetic ground-truth image +# - $m$: mask, equal to 1 on known data and 0 on the missing region +# - $f = m u_{\mathrm{true}}$: observed incomplete image +# - $u$: reconstructed image +# +# We compute $u$ by minimising +# +# $$ +# J(u)= {1 \over 2}\beta \int_\Omega m(u-f)^2\,\mathrm{d}x +# + \alpha \int_\Omega \sqrt{||\nabla u||^2 + \varepsilon^2}~\mathrm{d}x. +# $$ +# +# The first term enforces agreement with the known image data, while +# the second term is a smoothed total variation regularisation term. +# It promotes piecewise smooth solutions and preserves edges. +# $\alpha$ and $\beta$ control the balance between data fidelity +# (fit to $f$) and smoothness. +# The parameter $\varepsilon>0$ smooths the TV functional so that +# it is differentiable and the problem can be solved with Newton-type methods. +# +# ## Discretization +# We discretize the problem using +# - a first order Lagrange finite element space +# - a triangular mesh of the unit square +# +# ## Implementation +# +# We use a first-order Lagrange space on a triangular mesh +# of the unit square. +# The nonlinear problem is solved with +# {py:class}`PETSc SNES` through +# {py:class}`NonlinearProblem `. 
+ + +# + +from mpi4py import MPI + +import matplotlib.pyplot as plt +import matplotlib.tri as mtri +import numpy as np + +import ufl +from dolfinx import fem, mesh +from dolfinx.fem.petsc import NonlinearProblem + +# - + + +# We discretize the domain $\Omega =[0,1]^2$ using a triangular +# mesh, where `nx` and `ny` control the resolution of the mesh. + +nx = 128 +ny = 128 +msh = mesh.create_unit_square(MPI.COMM_WORLD, nx, ny) + +# We use first order Lagrange elements for discretizing the image. +# In this space, the DOFs are the values of $u$ at the mesh vertices; +# the solution is continuous but has a piecewise constant gradient. + +V = fem.functionspace(msh, ("Lagrange", 1)) + +# ### Ground Truth image $u_{true}$ +# We define a synthetic binary image +# +# $$ +# u_{true}=\begin{cases} +# 1 & \text{ if } (x,y) \text{ is inside a square}\\ +# 0 &\text{ otherwise} +# \end{cases} +# $$ +# +# The square is defined as $0.2 < x < 0.8$, $0.2 < y < 0.8$. + + +def true_image(x): + """Return 1 inside the square and 0 elsewhere.""" + X = x[0] + Y = x[1] + return ((X > 0.2) & (X < 0.8) & (Y > 0.2) & (Y < 0.8)).astype(np.float64) + + +# ### Mask $m(x,y)$ +# The mask defines which pixels are known and which are missing +# +# $$ +# m(x,y)=\begin{cases} +# 1& \text{ known data}\\ +# 0 & \text{ missing region} +# \end{cases} +# $$ +# +# We construct a mask with random "holes" inside the square: +# * small circular regions are removed and set to 0 +# * everywhere else remains known (1) + +# This creates a challenging inpainting problem, with: +# * many small missing regions +# * irregular geometry +# +# The solver must reconstruct these missing values +# using smoothness (TV regularization) + + +def mask_function(x): + """Create a mask with random circular holes inside the square.""" + X = x[0] + Y = x[1] + # all pixels known + mask = np.ones_like(X, dtype=np.float64) + # number of speckles + num_speckles = 25 + # random centers + generator = np.random.Generator(np.random.MT19937(0)) # fixed seed for reproducibility + + cx = generator.uniform(0.25, 0.75, num_speckles) + cy = generator.uniform(0.25, 0.75, num_speckles) + # random radii 
(small + varied) + radii = generator.uniform(0.012, 0.035, num_speckles) + # create holes. mask =0 inside circles + for i in range(num_speckles): + r2 = (X - cx[i]) ** 2 + (Y - cy[i]) ** 2 + mask[r2 < radii[i] ** 2] = 0.0 + return mask + + +# We interpolate the exact image and the mask into the finite element +# space, and construct the observed damaged image, +# where $u_{true}$ is our true image, $m: \Omega \to \mathbb{R}$ +# is the mask, $f: \Omega \to \mathbb{R}$ is the observed damaged image, +# and $u:\Omega \to \mathbb{R}$ is the reconstructed image. + +u_true = fem.Function(V, name="true_image") +u_true.interpolate(true_image) +m = fem.Function(V, name="mask") +m.interpolate(mask_function) +f = fem.Function(V, name="observed_image") +f.x.array[:] = m.x.array * u_true.x.array +u = fem.Function(V, name="reconstructed_image") +u.x.array[:] = f.x.array.copy() + +# We now define the nonlinear variational problem corresponding to the +# smoothed total variation regularised inpainting model. +# +# The Euler-Lagrange equation for $J(u)$ leads to the weak form: +# find $u\in V$ such that +# +# $$ +# \beta \int_\Omega m(u-f)v\,\mathrm{d}x +# + \alpha \int_\Omega +# {\nabla u\cdot\nabla v \over \sqrt{||\nabla u||^2+\varepsilon^2}} +# \,\mathrm{d}x = 0 +# $$ +# +# for all test functions $v$. This is a nonlinear problem due to +# the TV term. +# The total variation integrand is $\vert\vert\nabla u\vert\vert$, +# but in practice one uses a smoothed version to allow for differentiation +# and Newton-type solvers: +# +# $$ +# TV = \sqrt{\vert\vert\nabla u\vert\vert^2 +\varepsilon^2} +# $$ +# +# where $\varepsilon$ controls the smoothing of the TV term: +# * large $\varepsilon$: smoother, more like quadratic diffusion +# * small $\varepsilon$: closer to true TV, edge preserving + +alpha = fem.Constant(msh, 0.003) +beta = fem.Constant(msh, 1.0) +eps = fem.Constant(msh, 1.0e-4) + +# Smoothed TV inpainting energy functional. 
+# We define the energy $J(u)$ and use `ufl.derivative` to obtain +# the residual form $F$. +# +# $$ +# J(u) = {1 \over 2}\beta\int_\Omega m(u-f)^2\,dx +# + \alpha\int_\Omega \sqrt{||\nabla u||^2+\varepsilon^2}\,dx +# $$ +# +# Taking the first variation gives the weak form $F(u; v)$. +# +# $$ +# F(u; v) = +# \beta\int_\Omega m(u-f)v\,dx +# + \alpha\int_\Omega +# {\nabla u\cdot\nabla v +# \over +# \sqrt{||\nabla u||^2+\varepsilon^2}}\,dx +# = 0 \quad \forall v\in V. +# $$ + +v = ufl.TestFunction(V) + +J_energy = ( + 0.5 * beta * m * (u - f) ** 2 * ufl.dx + + alpha * ufl.sqrt(ufl.inner(ufl.grad(u), ufl.grad(u)) + eps**2) * ufl.dx +) + +F = ufl.derivative(J_energy, u, v) + +# This formulation is based on total variation (TV) regularization +# for image denoising and inpainting +# {cite:t}`tv-RUDIN1992TV,tv-CHAN2001TV`. + +# A nonlinear PETSc problem is created and solved with a Newton line-search +# method, with an LU factorization for the linearized system +# $F'(u_k) s= -F(u_k)$. + +# + +petsc_options = { + "snes_type": "newtonls", + "snes_linesearch_type": "bt", + "snes_rtol": 1.0e-8, + "snes_atol": 1.0e-8, + "snes_max_it": 1000, + "ksp_type": "preonly", + "pc_type": "lu", +} + +problem = NonlinearProblem( + F, + u, + bcs=[], + petsc_options_prefix="tv_inpainting_", + petsc_options=petsc_options, +) + +problem.solve() +# - + +# ## Model Validation and Results +# These diagnostics assess +# 1. whether the nonlinear Newton/SNES solve converged +# 2. whether the variational objective decreased +# 3. how accurate the reconstruction is globally and in the hole region + +# FEM Metrics +# The global number of degrees of freedom reports the size of the +# finite element discretisation. The H1 seminorm error measures the +# gradient error +# +# $$ +# \vert\vert\nabla(u-u_{true})\vert\vert_{L_2 (\Omega)} +# $$ +# +# This is useful as TV regularization is gradient based. 
+# Smaller values mean the reconstruction recovers edge structure better. + +num_dofs = V.dofmap.index_map.size_global +h1_semi_error = fem.assemble_scalar( + fem.form(ufl.inner(ufl.grad(u - u_true), ufl.grad(u - u_true)) * ufl.dx) +) +h1_semi_error = np.sqrt(msh.comm.allreduce(h1_semi_error, op=MPI.SUM)) + +# Reconstruction Errors +# Data fidelity (known region only): +# +# $$ +# \left(\int_\Omega m(u-f)^2\,\mathrm{d}x\right)^{1/2} +# $$ +# +# measures the agreement with the known image data. +# Smaller values mean the reconstruction matches the observed pixels better. + +data_error = fem.assemble_scalar(fem.form(m * (u - f) ** 2 * ufl.dx)) +data_error = np.sqrt(msh.comm.allreduce(data_error, op=MPI.SUM)) + +# Smoothed TV seminorm +# +# $$ +# \int_{\Omega}\sqrt{\vert\vert\nabla u \vert\vert^2 +# +\varepsilon^2}~\mathrm{d}x +# $$ +# +# This is the regularization term in the objective (without the factor $\alpha$). +# Smaller values mean a smoother reconstruction. + +tv_energy = fem.assemble_scalar( + fem.form(ufl.sqrt(ufl.inner(ufl.grad(u), ufl.grad(u)) + eps**2) * ufl.dx) +) +tv_energy = msh.comm.allreduce(tv_energy, op=MPI.SUM) + +# True error +# +# $$ +# \vert\vert u-u_{true} \vert\vert_{L_2 (\Omega)} +# $$ +# +# Measures overall reconstruction accuracy. + +true_error = fem.assemble_scalar(fem.form((u - u_true) ** 2 * ufl.dx)) +true_error = np.sqrt(msh.comm.allreduce(true_error, op=MPI.SUM)) + +# Hole error +# +# $$ +# \left(\int_\Omega (1-m)(u-u_{true})^2\,\mathrm{d}x\right)^{1/2} +# $$ + +hole_error = fem.assemble_scalar(fem.form((1 - m) * (u - u_true) ** 2 * ufl.dx)) +hole_error = np.sqrt(msh.comm.allreduce(hole_error, op=MPI.SUM)) + +# Image quality metric +# PSNR (peak signal-to-noise ratio) is a standard imaging metric. +# Since the image range is [0,1], we use +# +# $$ +# PSNR=10\log_{10}(1/MSE) +# $$ +# +# Larger PSNR means better reconstruction quality. + +mse = np.mean((u.x.array - u_true.x.array) ** 2) +if mse == 0: + psnr = np.inf +else: + psnr = 10.0 * np.log10(1.0 / mse) + +# Newton Linesearch 
metrics +# +# These measure whether the nonlinear solve succeeded. We want +# * a positive converged reason +# * a small final residual norm +# * a reasonable number of iterations + +snes = problem.solver +reason = snes.getConvergedReason() +iters = snes.getIterationNumber() +final_residual = snes.getFunctionNorm() + +# Objective values +# +# We compare the initial objective J(f) +# with the final objective J(u), where +# +# $$ +# J(v)={1\over 2}\beta \int_\Omega m(v-f)^2\,dx +# +\alpha \int_\Omega \sqrt{||\nabla v||^2+\varepsilon^2}\,dx +# $$ +# +# A decrease in the objective shows that the nonlinear optimization +# improved the damaged image under the smoothed TV model. + +# + +objective_value = 0.5 * float(beta) * data_error**2 + float(alpha) * tv_energy +if reason > 0: + status = "converged" +else: + status = "not converged" + +u0 = fem.Function(V) +u0.x.array[:] = f.x.array + +J0_data = fem.assemble_scalar(fem.form(m * (u0 - f) ** 2 * ufl.dx)) +J0_data = msh.comm.allreduce(J0_data, op=MPI.SUM) + +J0_tv = fem.assemble_scalar( + fem.form(ufl.sqrt(ufl.inner(ufl.grad(u0), ufl.grad(u0)) + eps**2) * ufl.dx) +) +J0_tv = msh.comm.allreduce(J0_tv, op=MPI.SUM) + +J0 = 0.5 * float(beta) * J0_data + float(alpha) * J0_tv +# - + + +# Print the validation results and metrics on the main process only +if msh.comm.rank == 0: + print("---Smoothed TV inpainting results---") + + print("--FEM Metrics--") + print(f"Global DOFs: {num_dofs}") + print(f"H1 seminorm error: {h1_semi_error}") + + print("--Newton Linesearch:--") + print("-Optimization:-") + print(f"Initial objective J(f): {J0:.4e}") + print(f"Final objective J(u): {objective_value:.4e}") + print(f"Relative decrease: {(J0 - objective_value) / J0:.2%}") + + print("-Solver convergence:-") + print(f"SNES iterations: {iters}") + print(f"SNES final residual norm: {final_residual:.4e}") + print(f"SNES status: {status}") + print(f"SNES converged reason: {reason}") + + print("---Reconstruction Quality:---") + print(f"Data error (known region): {data_error:.4e}")
+ print(f"TV seminorm: {tv_energy:.4e}") + print(f"True L2 error: {true_error:.4e}") + print(f"Hole error: {hole_error:.4e}") + print(f"PSNR: {psnr:.2f} dB") + +# ## Visualization +# We construct fields that allow us to visually assess the quality +# of the reconstruction: +# - $u-u_{true}$ is the global reconstruction error +# - $(1-m)(u-u_{true})$ is the hole error, restricted to the missing regions + +# + +u_minus_u_true = fem.Function(V) +u_minus_u_true.x.array[:] = u.x.array - u_true.x.array + +hole_error_field = fem.Function(V) +hole_error_field.x.array[:] = (1.0 - m.x.array) * (u.x.array - u_true.x.array) +# - + +# ### FEM to matplotlib +# The FEM solution u is represented by values at degrees +# of freedom (DOFs), not on a regular grid. +# To plot it in matplotlib, we +# 1. extract the coordinates of the DOFs +# 2. extract the function-space dofmap connectivity (triangles) +# 3. build a Triangulation object +# +# This allows matplotlib to render the piecewise linear FEM solution. + +coords = V.tabulate_dof_coordinates() +x, y = coords[:, 0], coords[:, 1] + +triangles = V.dofmap.list +triang = mtri.Triangulation(x, y, triangles) + + +# Plotting +# +# We use tripcolor to plot scalar fields defined on a triangulated mesh. +# With `shading="flat"`, each triangle is colored with the mean of its +# three vertex values, which makes the individual mesh cells visible. + +# + + + +def plot_field(ax, data, title, fig, cmap="viridis", vmin=0.0, vmax=1.0): + """Plot a scalar field on a triangulated mesh.""" + im = ax.tripcolor(triang, data, shading="flat", cmap=cmap, vmin=vmin, vmax=vmax) + ax.set_title(title) + ax.set_aspect("equal") + fig.colorbar(im, ax=ax) + + +fig, axes = plt.subplots(2, 3, figsize=(12, 8)) +# u_true is the ground-truth image +plot_field(axes[0, 0], u_true.x.array, "u_true", fig) +# m is the mask, with known (1) and missing (0) regions +plot_field(axes[0, 1], m.x.array, "mask", fig, cmap="gray") +# f is the damaged image +plot_field(axes[0, 2], f.x.array, "f", fig) +# u is the reconstructed
image +plot_field(axes[1, 0], u.x.array, "u", fig) +# Global reconstruction error u - u_true, with symmetric color limits +lim = np.max(np.abs(u_minus_u_true.x.array)) +plot_field( + axes[1, 1], u_minus_u_true.x.array, "u - u_true", fig, cmap="coolwarm", vmin=-lim, vmax=lim +) +# Hole-only error, restricted to the missing regions, with symmetric color limits +lim = np.max(np.abs(hole_error_field.x.array)) +plot_field( + axes[1, 2], + hole_error_field.x.array, + "hole-only error", + fig, + cmap="coolwarm", + vmin=-lim, + vmax=lim, +) +plt.tight_layout() +plt.show() + +# - + +# ## References +# ```{bibliography} +# :filter: cited +# :labelprefix: +# :keyprefix: tv- +# ```
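The relation between the energy J and its first variation F used in this demo can be checked on a small discrete analogue. The following is a minimal NumPy sketch on a 1D uniform grid with forward differences; it is not the DOLFINx implementation. The function names `smoothed_tv_energy` and `smoothed_tv_gradient`, the grid, the mask interval, and the values of `alpha`, `beta` and `eps` are all illustrative assumptions, not taken from the demo.

```python
import numpy as np


def smoothed_tv_energy(u, f, m, h, alpha, beta, eps):
    """Discrete 1D analogue of J(u) on a uniform grid with spacing h."""
    du = np.diff(u) / h  # forward-difference approximation of the gradient
    data = 0.5 * beta * h * np.sum(m * (u - f) ** 2)
    tv = alpha * h * np.sum(np.sqrt(du**2 + eps**2))
    return data + tv


def smoothed_tv_gradient(u, f, m, h, alpha, beta, eps):
    """Gradient of the discrete energy, the analogue of F(u; v)."""
    du = np.diff(u) / h
    flux = du / np.sqrt(du**2 + eps**2)  # u' / sqrt(u'^2 + eps^2)
    grad = beta * h * m * (u - f)
    # Each difference du_i depends on u_i (weight -1/h) and u_{i+1} (weight +1/h)
    grad[:-1] -= alpha * flux
    grad[1:] += alpha * flux
    return grad


# Hypothetical test data: a step image with a masked interval around the edge
rng = np.random.default_rng(0)
n = 50
x = np.linspace(0.0, 1.0, n)
h = x[1] - x[0]
u_true = (x > 0.5).astype(float)
m = np.where((x > 0.4) & (x < 0.6), 0.0, 1.0)
f = m * u_true
u = f + 0.1 * rng.standard_normal(n)
alpha, beta, eps = 1.0e-2, 1.0, 0.1  # illustrative values only

# Compare the analytic directional derivative with a central finite difference
g = smoothed_tv_gradient(u, f, m, h, alpha, beta, eps)
v = rng.standard_normal(n)
delta = 1.0e-6
fd = (
    smoothed_tv_energy(u + delta * v, f, m, h, alpha, beta, eps)
    - smoothed_tv_energy(u - delta * v, f, m, h, alpha, beta, eps)
) / (2.0 * delta)
gv = float(g @ v)
print(f"analytic: {gv:.8e}, finite difference: {fd:.8e}")
```

The two printed values should agree closely, which mirrors what `ufl.derivative` provides symbolically in the demo: the residual is the directional derivative of the energy, so no hand-derived weak form has to be kept in sync with J.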