-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_distribution.py
More file actions
37 lines (26 loc) · 973 Bytes
/
plot_distribution.py
File metadata and controls
37 lines (26 loc) · 973 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
Plot the data distribution of a specific dataset.
"""
from argparse import ArgumentParser
from pathlib import Path
import matplotlib.pyplot as plt
from nfs import DATASETS_2D
from nfs.datasets import FlowDataset
# Directory where rendered plots are written: an "assets" folder in the same
# directory as this file. It is created eagerly at import time (including any
# missing parents), so later saves into it do not need to create it.
ASSETS_DIR = Path(__file__).parent.joinpath("assets")
ASSETS_DIR.mkdir(exist_ok=True, parents=True)
def main() -> None:
    """Sample a 2-D dataset and save a scatter plot of the samples.

    Command-line arguments:
        --dataset: key into ``DATASETS_2D``; the mapped value is called with
            no arguments to build the dataset.
        --num-samples: number of points to draw; must be positive
            (default 10000).

    The plot is written to ``ASSETS_DIR / "<dataset>.png"`` and the output
    path is printed.
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "--dataset",
        type=str,
        choices=DATASETS_2D.keys(),
        required=True,
        help="name of the 2-D dataset to sample",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10000,
        help="number of points to draw (default: %(default)s)",
    )
    args = parser.parse_args()
    # Reject non-positive counts with a clean usage error (exit status 2)
    # instead of plotting an empty figure or failing inside dataset.sample().
    if args.num_samples <= 0:
        parser.error(
            f"--num-samples must be a positive integer, got {args.num_samples}"
        )

    dataset: FlowDataset = DATASETS_2D[args.dataset]()
    # NOTE(review): assumes sample() returns an (N, 2) array-like supporting
    # x[:, i] indexing that matplotlib can consume (e.g. a NumPy array or a
    # CPU tensor without grad) — confirm against FlowDataset.sample.
    x = dataset.sample(args.num_samples)

    # Use an explicit Figure/Axes pair so the figure can be closed after
    # saving rather than lingering in pyplot's global figure registry.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(x[:, 0], x[:, 1], s=5, alpha=0.5)
    ax.axis("equal")
    ax.set_title(f"{args.dataset}: {args.num_samples:,} samples")

    path = ASSETS_DIR / f"{args.dataset}.png"
    fig.savefig(path)
    plt.close(fig)
    print(f"Saved to {path}")
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()