From 55db45c2bed041dd7f7874f356a6ba1bf9e91848 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 2 May 2026 18:42:37 +0330 Subject: [PATCH 1/7] upgrade to new sciml versions --- Project.toml | 6 ++---- examples/Project.toml | 2 +- examples/usage.jl | 4 ++-- src/ContinuousNormalizingFlows.jl | 1 - src/core/icnf.jl | 4 +++- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index bb43dee4..f2e219ce 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" -Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -58,13 +57,12 @@ NNlib = "0.9" Octavian = "0.3" Optimisers = "0.4" OptimizationOptimisers = "0.3" -OrdinaryDiffEqAdamsBashforthMoulton = "1, 2" +OrdinaryDiffEqAdamsBashforthMoulton = "2" Polyester = "0.7" Random = "1" -SciMLBase = "2, 3" +SciMLBase = "3" SciMLSensitivity = "7" ScientificTypesBase = "3" -Static = "1" Statistics = "1" WeightInitializers = "1" Zygote = "0.7" diff --git a/examples/Project.toml b/examples/Project.toml index d9431b5f..08b895da 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,6 +5,7 @@ ContinuousNormalizingFlows = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -12,6 +13,5 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/usage.jl b/examples/usage.jl index 0d790e0f..6d9650e2 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -21,7 +21,7 @@ n_hidden = n_in * 4 using ContinuousNormalizingFlows, Lux, OrdinaryDiffEqAdamsBashforthMoulton, - Static, + FastBroadcast, SciMLSensitivity, ADTypes, Zygote, @@ -54,7 +54,7 @@ icnf = ICNF(; maxiters = typemax(Int), reltol = sqrt(eps(Float32)), abstol = sqrt(eps(Float32)), - alg = VCABM(; thread = True()), + alg = VCABM(; thread = Threaded()), sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true), ), # pass to the solver ) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 62020a81..cf061024 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -27,7 +27,6 @@ import ADTypes, SciMLBase, SciMLSensitivity, ScientificTypesBase, - Static, Statistics, WeightInitializers, Zygote diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 288af831..6170376a 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -86,7 +86,9 @@ function ICNF(; maxiters = typemax(Int), reltol = sqrt(eps(data_type)), abstol = sqrt(eps(data_type)), - alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), + alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; + thread = FastBroadcast.Threaded(), + ), sensealg = SciMLSensitivity.InterpolatingAdjoint(; checkpointing = true, autodiff = true, From 1576d3dbe2fdc0572d469a05d1a06e41197a0774 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 3 May 2026 15:45:08 +0330 Subject: [PATCH 2/7] add verbose --- Project.toml | 4 ++++ src/ContinuousNormalizingFlows.jl | 2 ++ src/core/icnf.jl | 1 + src/exts/mlj_ext/core_cond_icnf.jl | 1 + src/exts/mlj_ext/core_icnf.jl | 1 + 5 files changed, 9 insertions(+) diff --git a/Project.toml b/Project.toml index f2e219ce..d37403a0 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -24,6 +25,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" @@ -40,6 +42,7 @@ ADTypes = "1" ChainRulesCore = "1" ComponentArrays = "0.15" DataFrames = "1" +DiffEqBase = "7" DifferentiationInterface = "0.7" Distributions = "0.25" DistributionsAD = "0.6" @@ -56,6 +59,7 @@ MLUtils = "0.4" NNlib = "0.9" Octavian = "0.3" Optimisers = "0.4" +OptimizationBase = "5" OptimizationOptimisers = "0.3" OrdinaryDiffEqAdamsBashforthMoulton = "2" Polyester = "0.7" diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index cf061024..c5925490 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -4,6 +4,7 @@ import ADTypes, ChainRulesCore, ComponentArrays, DataFrames, + DiffEqBase, DifferentiationInterface, Distributions, DistributionsAD, @@ -20,6 +21,7 @@ import ADTypes, NNlib, Octavian, Optimisers, + OptimizationBase, OptimizationOptimisers, OrdinaryDiffEqAdamsBashforthMoulton, Polyester, diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 6170376a..cc901734 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -93,6 +93,7 @@ function ICNF(; checkpointing = true, autodiff = true, ), + verbose = DiffEqBase.DEVerbosity.All(), ), ) steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 9d9592bd..4adb61a7 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -31,6 +31,7 @@ function CondICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), + verbose = OptimizationBase.OptimizationVerbosity.All(), ), ) return CondICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 420d1f23..d550189c 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -31,6 +31,7 @@ function ICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), + verbose = OptimizationBase.OptimizationVerbosity.All(), ), ) return ICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) From 269fced14d2dc35b641b12996afc7f0c7c5c378c Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 3 May 2026 15:55:01 +0330 Subject: [PATCH 3/7] fix --- Project.toml | 2 ++ src/ContinuousNormalizingFlows.jl | 1 + src/core/icnf.jl | 2 +- src/exts/mlj_ext/core_cond_icnf.jl | 2 +- src/exts/mlj_ext/core_icnf.jl | 2 +- 5 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index d37403a0..215767be 100644 --- a/Project.toml +++ b/Project.toml @@ -31,6 +31,7 @@ OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -65,6 +66,7 @@ OrdinaryDiffEqAdamsBashforthMoulton = "2" Polyester = "0.7" Random = "1" SciMLBase = "3" +SciMLLogging = "1, 2" SciMLSensitivity = "7" ScientificTypesBase = "3" Statistics = "1" diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index c5925490..2d66c199 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -27,6 +27,7 @@ import ADTypes, Polyester, Random, SciMLBase, + SciMLLogging, SciMLSensitivity, ScientificTypesBase, Statistics, diff --git a/src/core/icnf.jl b/src/core/icnf.jl index cc901734..89828685 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -93,7 +93,7 @@ function ICNF(; checkpointing = true, autodiff = true, ), - verbose = DiffEqBase.DEVerbosity.All(), + verbose = DiffEqBase.DEVerbosity(SciMLLogging.All()), ), ) steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 4adb61a7..421376f0 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -31,7 +31,7 @@ function CondICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), - verbose = OptimizationBase.OptimizationVerbosity.All(), + verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.All()), ), ) return CondICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index d550189c..8d1a16f0 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -31,7 +31,7 @@ function ICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), - verbose = OptimizationBase.OptimizationVerbosity.All(), + verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.All()), ), ) return ICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) From c669183a4abda42008856bce3ffbdff0dfcf3ce3 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 3 May 2026 16:03:01 +0330 Subject: [PATCH 4/7] switch to `Detailed` --- src/core/icnf.jl | 2 +- src/exts/mlj_ext/core_cond_icnf.jl | 2 +- src/exts/mlj_ext/core_icnf.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 89828685..a7249310 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -93,7 +93,7 @@ function ICNF(; checkpointing = true, autodiff = true, ), - verbose = DiffEqBase.DEVerbosity(SciMLLogging.All()), + verbose = DiffEqBase.DEVerbosity(SciMLLogging.Detailed()), ), ) steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 421376f0..cd851f8c 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -31,7 +31,7 @@ function CondICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), - verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.All()), + verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.Detailed()), ), ) return CondICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 8d1a16f0..b389dcf5 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -31,7 +31,7 @@ function ICNFModel(; epochs = 300, progress = true, callback = make_opt_callback(64), - verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.All()), + verbose = OptimizationBase.OptimizationVerbosity(SciMLLogging.Detailed()), ), ) return ICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) From 54e8d15541e837155549e755324a7ec302bb760d Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 3 May 2026 18:22:39 +0330 Subject: [PATCH 5/7] test it --- src/core/icnf.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index a7249310..75f79f3e 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -92,6 +92,7 @@ function ICNF(; sensealg = SciMLSensitivity.InterpolatingAdjoint(; checkpointing = true, autodiff = true, + autojacvec = ifelse(inplace, true, SciMLSensitivity.ZygoteVJP()), ), verbose = DiffEqBase.DEVerbosity(SciMLLogging.Detailed()), ), From 0ea31824be3cad91ba708da562ee96610926de5c Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 3 May 2026 21:35:24 +0330 Subject: [PATCH 6/7] test with `autojacvec` --- examples/usage.jl | 6 +++++- src/core/icnf.jl | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/examples/usage.jl b/examples/usage.jl index 6d9650e2..163f9773 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -55,7 +55,11 @@ icnf = ICNF(; reltol = sqrt(eps(Float32)), abstol = sqrt(eps(Float32)), alg = VCABM(; thread = Threaded()), - sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true), + sensealg = InterpolatingAdjoint(; + checkpointing = true, + autodiff = true, + autojacvec = ZygoteVJP(), + ), ), # pass to the solver ) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index a7249310..63b38a0b 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -92,6 +92,15 @@ function ICNF(; sensealg = SciMLSensitivity.InterpolatingAdjoint(; checkpointing = true, autodiff = true, + autojacvec = ifelse( + inplace, + true, + ifelse( + isa(compute_mode, LuxMatrixMode), + SciMLSensitivity.ZygoteVJP(), + true, + ), + ), ), verbose = DiffEqBase.DEVerbosity(SciMLLogging.Detailed()), ), From bc0c7f94f83fd80dd560380085866b390094002e Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 3 May 2026 21:38:21 +0330 Subject: [PATCH 7/7] better --- src/core/icnf.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 63b38a0b..7214096e 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -95,11 +95,7 @@ function ICNF(; autojacvec = ifelse( inplace, true, - ifelse( - isa(compute_mode, LuxMatrixMode), - SciMLSensitivity.ZygoteVJP(), - true, - ), + ifelse(compute_mode isa LuxMatrixMode, SciMLSensitivity.ZygoteVJP(), true), ), ), verbose = DiffEqBase.DEVerbosity(SciMLLogging.Detailed()),