Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions ExampleCodes/ML/PYTORCH/Source/main.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
#include<torch/script.h> // One-stop header

#include <AMReX_PlotFileUtil.H>
#include <AMReX_ParmParse.H>
#include <algorithm>

#if !defined(__CUDA_ARCH__)
#include <ATen/core/ivalue.h>
#include <torch/torch.h>
#include <torch/script.h> // One-stop header
#endif

#include "myfunc.H"

Expand Down Expand Up @@ -143,11 +148,17 @@ void main_main ()

BL_PROFILE_VAR("LoadPytorch",LoadPytorch);

// Keep the model output in managed memory so the GPU copy-back does not
// depend on libtorch headers in the CUDA device compilation pass.
amrex::Gpu::ManagedVector<Real> aux_out;
Real* AMREX_RESTRICT auxOutPtr = nullptr;

#if !defined(__CUDA_ARCH__)
// Load pytorch module via torch script
torch::jit::script::Module module;
torch::jit::script::Module torch_module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(model_filename);
torch_module = torch::jit::load(model_filename);
}
catch (const c10::Error& e) {
amrex::Abort("Error loading the model\n");
Expand All @@ -160,7 +171,7 @@ void main_main ()

#ifdef AMREX_USE_CUDA
torch::Device device0(torch::kCUDA);
module.to(device0);
torch_module.to(device0);
amrex::Print() << "Copying model to GPU." << std::endl;

// set tensor options
Expand All @@ -175,6 +186,7 @@ void main_main ()
// EVALUATE MODEL

BL_PROFILE_VAR("Eval",Eval);
#endif

// loop over boxes
for ( MFIter mfi(phi_in); mfi.isValid(); ++mfi )
Expand Down Expand Up @@ -211,24 +223,26 @@ void main_main ()
auxPtr[index*Nc_in + n] = phi_input(i, j, k, n);
});

aux_out.resize(ncell*Nc_out);
auxOutPtr = aux_out.dataPtr();

#if !defined(__CUDA_ARCH__)
// create torch tensor from array
at::Tensor inputs_torch = torch::from_blob(auxPtr, {ncell, Nc_in}, tensoropt);

// store the current time so we can later compute total eval time.
Real eval_t_start = ParallelDescriptor::second();

// evaluate torch model
at::Tensor outputs_torch = module.forward({inputs_torch}).toTensor();
outputs_torch = outputs_torch.to(dtype0);
at::Tensor outputs_torch = torch_module.forward({inputs_torch}).toTensor();
outputs_torch = outputs_torch.to(dtype0).to(torch::kCPU).contiguous();

// add eval time
eval_time += ParallelDescriptor::second() - eval_t_start;

// get accessor to tensor (read-only)
#ifdef AMREX_USE_CUDA
auto outputs_torch_acc = outputs_torch.packed_accessor64<Real,2>();
#else
auto outputs_torch_acc = outputs_torch.accessor<Real,2>();
std::copy(outputs_torch.data_ptr<Real>(),
outputs_torch.data_ptr<Real>() + ncell*Nc_out,
auxOutPtr);
#endif

// copy tensor to output multifab
Expand All @@ -241,11 +255,13 @@ void main_main ()
int kk = k - bx_lo[2];
index += kk*nbox[0]*nbox[1];
#endif
phi_output(i, j, k, n) = outputs_torch_acc[index][n];
phi_output(i, j, k, n) = auxOutPtr[index*Nc_out + n];
});
}

#if !defined(__CUDA_ARCH__)
BL_PROFILE_VAR_STOP(Eval);
#endif

BL_PROFILE_VAR("Post",Post);

Expand Down
Loading