diff --git a/Project.toml b/Project.toml index bb43dee4..215767be 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -24,14 +25,15 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" -Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -41,6 +43,7 @@ ADTypes = "1" ChainRulesCore = "1" ComponentArrays = "0.15" DataFrames = "1" +DiffEqBase = "7" DifferentiationInterface = "0.7" Distributions = "0.25" DistributionsAD = "0.6" @@ -57,14 +60,15 @@ MLUtils = "0.4" NNlib = "0.9" Octavian = "0.3" Optimisers = "0.4" +OptimizationBase = "5" OptimizationOptimisers = "0.3" -OrdinaryDiffEqAdamsBashforthMoulton = "1, 2" +OrdinaryDiffEqAdamsBashforthMoulton = "2" Polyester = "0.7" Random = "1" -SciMLBase = "2, 3" +SciMLBase = "3" +SciMLLogging = "1, 2" SciMLSensitivity = "7" ScientificTypesBase = "3" -Static = "1" Statistics = "1" WeightInitializers = "1" Zygote = "0.7" diff --git a/examples/Project.toml b/examples/Project.toml index d9431b5f..08b895da 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,6 +5,7 @@ ContinuousNormalizingFlows = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -12,6 +13,5 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/usage.jl b/examples/usage.jl index 0d790e0f..163f9773 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -21,7 +21,7 @@ n_hidden = n_in * 4 using ContinuousNormalizingFlows, Lux, OrdinaryDiffEqAdamsBashforthMoulton, - Static, + FastBroadcast, SciMLSensitivity, ADTypes, Zygote, @@ -54,8 +54,12 @@ icnf = ICNF(; maxiters = typemax(Int), reltol = sqrt(eps(Float32)), abstol = sqrt(eps(Float32)), - alg = VCABM(; thread = True()), - sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true), + alg = VCABM(; thread = Threaded()), + sensealg = InterpolatingAdjoint(; + checkpointing = true, + autodiff = true, + autojacvec = ZygoteVJP(), + ), ), # pass to the solver ) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 62020a81..2d66c199 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -4,6 +4,7 @@ import ADTypes, ChainRulesCore, ComponentArrays, DataFrames, + DiffEqBase, DifferentiationInterface, Distributions, DistributionsAD, @@ -20,14 +21,15 @@ import ADTypes, NNlib, Octavian, Optimisers, + OptimizationBase, OptimizationOptimisers, OrdinaryDiffEqAdamsBashforthMoulton, Polyester, Random, SciMLBase, + SciMLLogging, SciMLSensitivity, ScientificTypesBase, - Static, Statistics, WeightInitializers, Zygote diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 288af831..7214096e 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -86,11 +86,19 @@ function ICNF(; maxiters = typemax(Int), reltol = sqrt(eps(data_type)), abstol = sqrt(eps(data_type)), - alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), + alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; + thread = FastBroadcast.Threaded(), + ), sensealg = SciMLSensitivity.InterpolatingAdjoint(; checkpointing = true, autodiff = true, + autojacvec = ifelse( + inplace, + true, + ifelse(compute_mode isa LuxMatrixMode, SciMLSensitivity.ZygoteVJP(), true), + ), ), + verbose = DiffEqBase.DEVerbosity(SciMLLogging.Detailed()), ), ) steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 9d9592bd..cd851f8c 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -31,6 +31,7 @@ function CondICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), + verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.Detailed()), ), ) return CondICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 420d1f23..b389dcf5 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -31,6 +31,7 @@ function ICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), + verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.Detailed()), ), ) return ICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs)