From 46c3261bf89d2e2084f0434a722f2f03636cf168 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 28 Apr 2026 18:17:40 +0330 Subject: [PATCH 1/3] add optional pkgs --- Project.toml | 10 +++++++++- src/ContinuousNormalizingFlows.jl | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d4d79e3d..fc224785 100644 --- a/Project.toml +++ b/Project.toml @@ -11,8 +11,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -20,9 +22,11 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" @@ -40,8 +44,10 @@ DataFrames = "1" DifferentiationInterface = "0.7" Distributions = "0.25" DistributionsAD = "0.6" +FastBroadcast = "0.3, 1" FillArrays = "1" LinearAlgebra = "1" +LoopVectorization = "0.12" Lux = "1" LuxCore = "1" MLDataDevices = "1" @@ -49,9 +55,11 @@ MLJBase = "1" MLJModelInterface = "1" MLUtils = "0.4" NNlib = "0.9" +Octavian = "0.3" Optimisers = "0.4" OptimizationOptimisers = "0.3" -OrdinaryDiffEqAdamsBashforthMoulton = "1" +OrdinaryDiffEqAdamsBashforthMoulton = "1, 2" +Polyester = "0.7" Random = "1" SciMLBase = "2, 3" SciMLSensitivity = "7" diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index dca2b40b..62020a81 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -7,8 +7,10 @@ import ADTypes, DifferentiationInterface, Distributions, DistributionsAD, + FastBroadcast, FillArrays, LinearAlgebra, + LoopVectorization, Lux, LuxCore, MLDataDevices, @@ -16,9 +18,11 @@ import ADTypes, MLJModelInterface, MLUtils, NNlib, + Octavian, Optimisers, OptimizationOptimisers, OrdinaryDiffEqAdamsBashforthMoulton, + Polyester, Random, SciMLBase, SciMLSensitivity, From 61bbbc78320753b1ced86325b8434e295ab23964 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 28 Apr 2026 18:17:52 +0330 Subject: [PATCH 2/3] pass the kwargs to problems --- src/core/base_icnf.jl | 24 ++++++++++++++++-------- src/exts/mlj_ext/core_cond_icnf.jl | 3 ++- src/exts/mlj_ext/core_icnf.jl | 3 ++- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/core/base_icnf.jl b/src/core/base_icnf.jl index 85e0574b..7b97a6ef 100644 --- a/src/core/base_icnf.jl +++ b/src/core/base_icnf.jl @@ -213,7 +213,8 @@ function inference_prob( ), vcat(xs, zrs), steer_tspan(icnf, mode), - ps, + ps; + icnf.sol_kwargs..., ) end @@ -238,7 +239,8 @@ function inference_prob( ), vcat(xs, zrs), steer_tspan(icnf, mode), - ps, + ps; + icnf.sol_kwargs..., ) end @@ -262,7 +264,8 @@ function inference_prob( ), vcat(xs, zrs), steer_tspan(icnf, mode), - ps, + ps; + icnf.sol_kwargs..., ) end @@ -287,7 +290,8 @@ function inference_prob( ), vcat(xs, zrs), steer_tspan(icnf, mode), - ps, + ps; + icnf.sol_kwargs..., ) end @@ -312,7 +316,8 @@ function generate_prob( ), vcat(new_xs, zrs), reverse(steer_tspan(icnf, mode)), - ps, + ps; + icnf.sol_kwargs..., ) end @@ -338,7 +343,8 @@ function generate_prob( ), vcat(new_xs, zrs), reverse(steer_tspan(icnf, mode)), - ps, + ps; + icnf.sol_kwargs..., ) end @@ -364,7 +370,8 @@ function generate_prob( ), vcat(new_xs, zrs), reverse(steer_tspan(icnf, mode)), - ps, + ps; + icnf.sol_kwargs..., ) end @@ -391,7 +398,8 @@ function generate_prob( ), vcat(new_xs, zrs), reverse(steer_tspan(icnf, mode)), - ps, + ps; + icnf.sol_kwargs..., ) end diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 83b0633c..9d9592bd 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -55,7 +55,8 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) model.adtype, ), ps, - data, + data; + model.sol_kwargs..., ) res_stats = Any[] for opt in model.optimizers diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 45c35002..420d1f23 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -52,7 +52,8 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X) model.adtype, ), ps, - data, + data; + model.sol_kwargs..., ) res_stats = Any[] for opt in model.optimizers From 09aafd807e2a769e453e7fc4b51b3bf2321bf437 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 28 Apr 2026 19:32:57 +0330 Subject: [PATCH 3/3] compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fc224785..bb43dee4 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ DataFrames = "1" DifferentiationInterface = "0.7" Distributions = "0.25" DistributionsAD = "0.6" -FastBroadcast = "0.3, 1" +FastBroadcast = "1" FillArrays = "1" LinearAlgebra = "1" LoopVectorization = "0.12"