Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions 02-sgd/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
data
runs
checkpoints
.snakemake
82 changes: 73 additions & 9 deletions 02-sgd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,64 @@

This is a complete example demonstrating how MPoL works using simulated data.

Before starting, you should have already run the scripts in the `generate-mock-baselines` folder to produce a mock sky image and interferometer baselines in a file called `mock_data.npz`. Then, you should copy that file to this repository under `data/mock_data.npz`.
# Prerequisites

This repository assumes that you will run all scripts from this `sgd` directory (the one containing `sgd/README.md`). Some aspects of the workflow are automated with Snakemake ([`Snakefile`](Snakefile)).
Before starting, you should have already run the scripts `00` and `01` folders to produce mock baselines. Then, you will need to copy the `mock_data.npz` into this directory in a new `data` folder. For example, from within this 02 folder, run

```shell
$ mkdir data
$ cp ../01-generate-mock-baselines/data/mock_data.npz data/
```

# Installation

You can install necessary Python packages into your environment by
```shell
$ pip install -r requirements.txt
```

and then you can run the code by

```shell
snakemake -c1 all
```

# Description of Contents

This repository assumes that you will run all scripts from this `02` directory (the one containing `02-sgd/README.md`). Some aspects of the workflow are automated with Snakemake ([`Snakefile`](Snakefile)).

First, we recommend looking at [`src/load_data.py`](src/load_data.py) to see how mock visibilities $\mathcal{V}(u,v)$ are generated from the mock image and baselines.

Then, we recommend looking at [`src/plot_baselines.py`](src/plot_baselines.py) and [`src/dirty_image.py`](src/dirty_image.py) to make diagnostic plots of the baseline and a dirty image of the data, to check that everything appears as you might expect.

You can run these simple scripts using

```
$ snakemake -c1 all
```

![baselines](analysis/baselines.png)

![Dirty Beam and Image](analysis/dirty_image.png)

# RML imaging workflow

The RML imaging workflow is demonstrated in [`src/sgd.py`](src/sgd.py). We recommend looking through that file before reading the rest of this document. If you are new to PyTorch idioms, we recommend familiarizing yourself with the [PyTorch basics](https://mpol-dev.github.io/MPoL/background.html#pytorch) first.

The RML imaging workflow is not part of the Snakemake workflow, instead, one runs the script like

```
$ python src/sgd.py --epochs=5
```

Note this will just result in a short test. Run `python src/sgd.py --help` to see all available command line arguments, and see below for configurations that will result in better images.

One can visualize the results using Tensorboard via

```shell
$ tensorboard --logdir runs
```

# Validation
Since this example uses mock data, we have the advantage of knowing the true sky image. This allows us to calculate a 'validation loss' between the synthesized image and the true sky.

Expand All @@ -25,25 +73,42 @@ This approach cannot be used with real datasets, obviously, but in this case aff

If the dataset lacks many long baselines, it is unrealistic for RML to recover the native resolution of the image. In this case, we can calculate the validation score at resolutions coarser than the source image. We do this by convolving both $I_\mathrm{true}$ and $I_\mathrm{syn}$ with a 2D Gaussian described by FWHMs of $\theta_a, \theta_b$ before computing $L_\mathrm{validation}$.

# (lack of) Regularization
To demonstrate why regularization is needed for imaging workflows, try running without any:
# Example Result
Here is an example image produced with

```shell
$ mkdir checkpoints
$ python src/sgd.py --tensorboard-log-dir=runs/ent0 --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=5 --lam-ent=1e-5
```

and plotted with

```shell
$ python src/plot_image.py checkpoints/ent0.pt analysis/butterfly.png
```
python src/sgd.py --tensorboard-log-dir=runs/nolam0 --epochs=40 --log-interval=2 --save-checkpoint=checkpoints/nolam0.pt --lr 1e-2

![RML Butterfly](analysis/butterfly.png)


# More examples of training loops
To demonstrate why regularization is needed for imaging workflows, try running without any:

```shell
python src/sgd.py --tensorboard-log-dir=runs/nolam0 --epochs=10 --log-interval=2 --save-checkpoint=checkpoints/nolam0.pt --lr 1e-2
```

If run to convergence, you'll find a classic case of overfitting to the lower S/N visibilities at longer baselines / higher spatial frequencies. This manifests in the image as small splotches and/or individual pixels with very high flux concentrations. If we didn't enforce non-negative pixels by construction, this would probably manifest as high frequency "noise" similar to uniformly-weighted images.

You can spot this behavior by monitoring the training loss and the validation loss with iteration. You will see the [classic textbook signature of overfitting](https://d2l.ai/chapter_linear-regression/generalization.html#underfitting-or-overfitting): the validation loss decreases for a while but eventually turns around and increases, while the training loss monotonically decreases as it fits the signal and then eventually tries to fit all the noise. One could attempt to regularize this behavior away using early stopping. However, in practice with real data we would not have access to a validation, so we look to alternative regularization techniques.

# Maximum Entropy Regularization
## Maximum Entropy Regularization

One can obtain a decent image using Maximum Entropy Regularization. Here are a few examples that you can run, saving checkpoints and resuming from finished models. We recommend that you examine the output using Tensorboard after each run, and make adjustments accordingly.

Initial run with no entropy:

```shell
python src/sgd.py --tensorboard-log-dir=runs/exp0 --save-checkpoint=checkpoints/0.pt --lr 1e-2 --FWHM 0.05 --epochs=50
python src/sgd.py --tensorboard-log-dir=runs/exp0 --save-checkpoint=checkpoints/0.pt --lr 1e-2 --FWHM 0.05 --epochs=10
```

Resuming from previous model, and speeding up learning rate
Expand All @@ -61,5 +126,4 @@ Adding entropy regularization, and reducing learning rate slightly.
python src/sgd.py --tensorboard-log-dir=runs/ent0 --load-checkpoint=checkpoints/2.pt --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=50 --lam-ent=1e-5
```


Note that we could have started directly with the entropy regularization if we wished. The previous just demonstrates an exploratory workflow.
Note that we could have started directly with the entropy regularization if we wished. This collection just demonstrates an exploratory workflow.
6 changes: 4 additions & 2 deletions 02-sgd/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ rule dirty_image:
# python src/sgd.py --tensorboard-log-dir=runs/exp2 --load-checkpoint=checkpoints/1.pt --save-checkpoint=checkpoints/2.pt --lr 4e-1 --FWHM 0.05 --epochs=50


# python src/sgd.py --tensorboard-log-dir=runs/ent0 --load-checkpoint=checkpoints/2.pt --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=50 --lam-ent=1e-5
# python src/sgd.py --tensorboard-log-dir=runs/ent0 --save-checkpoint=checkpoints/ent0.pt --lr 1e-1 --FWHM 0.05 --epochs=50 --lam-ent=1e-5


# vary fixed FWHM and entropy regularization to find best validation score.
# vary fixed FWHM and entropy regularization to find best validation score.

#
Binary file added 02-sgd/analysis/baselines.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 02-sgd/analysis/butterfly.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 02-sgd/analysis/dirty_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion 02-sgd/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mpol
tensorboard
visread
matplotlib
matplotlib
snakemake
55 changes: 55 additions & 0 deletions 02-sgd/src/plot_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import argparse
import matplotlib.pyplot as plt
from mpol import coordinates, images
from mpol.constants import arcsec
from astropy.visualization.mpl_normalize import simple_norm

def main():
parser = argparse.ArgumentParser(description="Compare image to DSHARP image")
parser.add_argument("load_checkpoint", metavar="load-checkpoint", help="Path to checkpoint from which to resume.")
parser.add_argument("plotfile")
args = parser.parse_args()

# get the MPoL image from the checkpoint
coords = coordinates.GridCoords(cell_size=0.005, npix=1028)
checkpoint = torch.load(args.load_checkpoint, map_location=torch.device('cpu'))

# get the image cube in packed format and run through an ImageCube to unpack
icube = images.ImageCube(coords=coords)
icube(checkpoint["model_state_dict"]["icube.packed_cube"])

# remove channel dimension
mpol_img = torch.squeeze(icube.sky_cube)

lmargin = 1.0
rmargin = lmargin
XX = 5. #in
ax_width = (XX - lmargin - rmargin)
ax_height = ax_width

cax_sep = 0.05
cax_width = 0.1
tmargin = 0.05
bmargin = 1.0
YY = bmargin + ax_height + tmargin

fig = plt.figure(figsize=(XX,YY))

ax = fig.add_axes((lmargin/XX, bmargin/YY, ax_width/XX, ax_height/YY))
cax = fig.add_axes(((lmargin + ax_width + cax_sep)/XX, bmargin/YY, cax_width/XX, ax_height/YY))

im = ax.imshow(mpol_img, extent=coords.img_ext, origin="lower", cmap="inferno")
cbar = plt.colorbar(im, cax=cax)
cbar.ax.tick_params(labelsize=9)
cbar.set_label(r"Jy/arcsec$^2$")

ax.set_xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
ax.set_ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")

fig.subplots_adjust(wspace=0.25)
fig.savefig(args.plotfile, dpi=300)


if __name__ == "__main__":
main()
5 changes: 3 additions & 2 deletions 02-sgd/src/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def main():
parser.add_argument(
"--batch-size",
type=int,
default=2000,
default=1000,
help="input batch size for training",
)
parser.add_argument(
Expand All @@ -207,7 +207,7 @@ def main():
parser.add_argument(
"--lr",
type=float,
default=1e-3,
default=1e-2,
help="learning rate",
)
parser.add_argument("--FWHM", type=float, default=0.05, help="FWHM of Gaussian Base layer in arcseconds.")
Expand Down Expand Up @@ -257,6 +257,7 @@ def main():
vis_data.uu, vis_data.vv, vis_data.weight, vis_data.data
)

print("running on ", device)
print("total vis", len(train_dataset))

# set the batch sizes for the loaders
Expand Down
Loading