Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ import ADTypes,

ndata = 2^10
ndimensions = 1
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
data_dist = Distributions.Beta(2.0, 4.0)
r = rand(data_dist, ndimensions, ndata)
r = convert.(Float32, r)

nvariables = size(r, 1)
icnf = ContinuousNormalizingFlows.ICNF(; nvariables)
Expand Down
27 changes: 13 additions & 14 deletions examples/usage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ global_logger(TerminalLogger())
using Distributions
ndata = 1024
ndimensions = 1
data_dist = Beta{Float32}(2.0f0, 4.0f0)
data_dist = Beta(2.0, 4.0)
r = rand(data_dist, ndimensions, ndata)
r = convert.(Float32, r)

## Parameters
nvariables = size(r, 1)
Expand Down Expand Up @@ -40,11 +39,11 @@ icnf = ICNF(;
nvariables = nvariables, # number of variables
naugments = naugments, # number of augmented dimensions
nconditions = 0, # number of conditioning inputs
λ₁ = 1.0f-2, # regulate flow
λ₂ = 1.0f-2, # regulate volume change
λ₃ = 1.0f-2, # regulate augmented dimensions
steer_rate = 1.0f-1, # add random noise to end of the time span
tspan = (0.0f0, 1.0f0), # time span
λ₁ = 0.01, # regulate flow
λ₂ = 0.01, # regulate volume change
λ₃ = 0.01, # regulate augmented dimensions
steer_rate = 0.1, # add random noise to end of the time span
tspan = (0.0, 1.0), # time span
device = cpu_device(), # process data by CPU
# device = gpu_device(), # process data by GPU
autonomous = false, # using non-autonomous flow
Expand All @@ -54,8 +53,8 @@ icnf = ICNF(;
sol_kwargs = (;
save_everystep = false,
maxiters = typemax(Int),
reltol = 1.0f-4,
abstol = 1.0f-4,
reltol = 1.0e-4,
abstol = 1.0e-4,
alg = VCABM(; thread = Threaded()),
sensealg = InterpolatingAdjoint(;
checkpointing = true,
Expand Down Expand Up @@ -84,9 +83,9 @@ if !isfile(icnf_mach_fn)
icnf,
optimizers = (
OptimiserChain(
WeightDecay(; lambda = 1.0f-4),
ClipNorm(1.0f0, 2.0f0; throw = true),
Adam(; eta = 1.0f-3, beta = (9.0f-1, 9.99f-1), epsilon = 1.0f-8),
WeightDecay(; lambda = 1.0e-4),
ClipNorm(1.0, 2.0; throw = true),
Adam(; eta = 0.001, beta = (0.9, 0.999), epsilon = 1.0e-8),
),
),
batchsize = 1024,
Expand Down Expand Up @@ -124,8 +123,8 @@ display(res_df)
using CairoMakie
f = Figure()
ax = Axis(f[1, 1]; title = "Result")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "Actual")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "Estimated")
lines!(ax, 0.0 .. 1.0, x -> pdf(data_dist, x); label = "Actual")
lines!(ax, 0.0 .. 1.0, x -> pdf(d, vcat(x)); label = "Estimated")
axislegend(ax)
save("result-figure.svg", f)
save("result-figure.png", f)
10 changes: 5 additions & 5 deletions src/core/icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct ICNF{
end

function ICNF(;
data_type::Type{<:AbstractFloat} = Float32,
data_type::Type{<:AbstractFloat} = Float64,
compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()),
inplace::Bool = false,
autonomous::Bool = false,
Expand All @@ -69,10 +69,10 @@ function ICNF(;
Lux.Dense(n_hidden => n_hidden, NNlib.softplus),
Lux.Dense(n_hidden => n_out),
),
steer_rate::AbstractFloat = convert(data_type, 1.0e-1),
λ₁::AbstractFloat = convert(data_type, 1.0e-2),
λ₂::AbstractFloat = convert(data_type, 1.0e-2),
λ₃::AbstractFloat = convert(data_type, 1.0e-2),
steer_rate::AbstractFloat = convert(data_type, 0.1),
λ₁::AbstractFloat = convert(data_type, 0.01),
λ₂::AbstractFloat = convert(data_type, 0.01),
λ₃::AbstractFloat = convert(data_type, 0.01),
basedist::Distributions.Distribution = Distributions.MvNormal(
FillArrays.Zeros{data_type}(nvariables + naugments),
FillArrays.Eye{data_type}(nvariables + naugments),
Expand Down
6 changes: 3 additions & 3 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ function CondICNFModel(;
Optimisers.WeightDecay(; lambda = convert(eltype(icnf), 1.0e-4)),
Optimisers.ClipNorm(
one(eltype(icnf)),
convert(eltype(icnf), 2.0e0);
convert(eltype(icnf), 2.0);
throw = true,
),
Optimisers.Adam(;
eta = convert(eltype(icnf), 1.0e-3),
beta = (convert(eltype(icnf), 9.0e-1), convert(eltype(icnf), 9.99e-1)),
eta = convert(eltype(icnf), 0.001),
beta = (convert(eltype(icnf), 0.9), convert(eltype(icnf), 0.999)),
epsilon = convert(eltype(icnf), 1.0e-8),
),
),
Expand Down
6 changes: 3 additions & 3 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ function ICNFModel(;
Optimisers.WeightDecay(; lambda = convert(eltype(icnf), 1.0e-4)),
Optimisers.ClipNorm(
one(eltype(icnf)),
convert(eltype(icnf), 2.0e0);
convert(eltype(icnf), 2.0);
throw = true,
),
Optimisers.Adam(;
eta = convert(eltype(icnf), 1.0e-3),
beta = (convert(eltype(icnf), 9.0e-1), convert(eltype(icnf), 9.99e-1)),
eta = convert(eltype(icnf), 0.001),
beta = (convert(eltype(icnf), 0.9), convert(eltype(icnf), 0.999)),
epsilon = convert(eltype(icnf), 1.0e-8),
),
),
Expand Down
2 changes: 1 addition & 1 deletion src/layers/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function PlanarLayer(
mapping::Pair{<:Int, <:Int},
activation::Any = identity;
init_weight::Any = WeightInitializers.glorot_uniform,
init_bias::Any = WeightInitializers.zeros32,
init_bias::Any = WeightInitializers.zeros64,
use_bias::Bool = true,
)
return PlanarLayer{
Expand Down
3 changes: 1 addition & 2 deletions test/ci_tests/regression_tests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
Test.@testset verbose = true showtiming = true failfast = false "Regression Tests" begin
ndata = 2^10
ndimensions = 1
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
data_dist = Distributions.Beta(2.0, 4.0)
r = rand(data_dist, ndimensions, ndata)
r = convert.(Float32, r)

nvariables = size(r, 1)
icnf = ContinuousNormalizingFlows.ICNF(; nvariables)
Expand Down
13 changes: 6 additions & 7 deletions test/ci_tests/smoke_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be

ndata = 4
ndimensions = 2
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
data_dist2 = Distributions.Beta{Float32}(2.0f0, 4.0f0)
data_dist = Distributions.Beta(2.0, 4.0)
data_dist2 = Distributions.Beta(2.0, 4.0)
if compute_mode isa ContinuousNormalizingFlows.VectorMode
r = convert.(Float32, rand(data_dist, ndimensions))
r2 = convert.(Float32, rand(data_dist2, ndimensions))
r = rand(data_dist, ndimensions)
r2 = rand(data_dist2, ndimensions)
elseif compute_mode isa ContinuousNormalizingFlows.MatrixMode
r = convert.(Float32, rand(data_dist, ndimensions, ndata))
r2 = convert.(Float32, rand(data_dist2, ndimensions, ndata))
r = rand(data_dist, ndimensions, ndata)
r2 = rand(data_dist2, ndimensions, ndata)
end
df = DataFrames.DataFrame(permutedims(r), :auto)
df2 = DataFrames.DataFrame(permutedims(r2), :auto)
Expand Down Expand Up @@ -124,7 +124,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be

Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in
adtypes

Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps))
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r))

Expand Down
4 changes: 1 addition & 3 deletions test/ci_tests/speed_tests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" begin
Test.@testset verbose = true showtiming = true failfast = false "$compute_mode" for compute_mode in
compute_modes

@show compute_mode

ndata = 2^10
ndimensions = 1
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
data_dist = Distributions.Beta(2.0, 4.0)
r = rand(data_dist, ndimensions, ndata)
r = convert.(Float32, r)

nvariables = size(r, 1)
icnf = ContinuousNormalizingFlows.ICNF(; nvariables, compute_mode)
Expand Down
12 changes: 6 additions & 6 deletions test/quality_tests/checkby_JET_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg

ndata = 4
ndimensions = 2
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
data_dist2 = Distributions.Beta{Float32}(2.0f0, 4.0f0)
data_dist = Distributions.Beta(2.0, 4.0)
data_dist2 = Distributions.Beta(2.0, 4.0)
if compute_mode isa ContinuousNormalizingFlows.VectorMode
r = convert.(Float32, rand(data_dist, ndimensions))
r2 = convert.(Float32, rand(data_dist2, ndimensions))
r = rand(data_dist, ndimensions)
r2 = rand(data_dist2, ndimensions)
elseif compute_mode isa ContinuousNormalizingFlows.MatrixMode
r = convert.(Float32, rand(data_dist, ndimensions, ndata))
r2 = convert.(Float32, rand(data_dist2, ndimensions, ndata))
r = rand(data_dist, ndimensions, ndata)
r2 = rand(data_dist2, ndimensions, ndata)
end
nvariables = size(r, 1)

Expand Down
Loading