diff --git a/ExampleCodes/ML/PYTORCH/Source/main.cpp b/ExampleCodes/ML/PYTORCH/Source/main.cpp index 374fa6cf..924a1870 100644 --- a/ExampleCodes/ML/PYTORCH/Source/main.cpp +++ b/ExampleCodes/ML/PYTORCH/Source/main.cpp @@ -1,7 +1,12 @@ -#include // One-stop header - #include #include +#include + +#if !defined(__CUDA_ARCH__) +#include +#include +#include // One-stop header +#endif #include "myfunc.H" @@ -143,11 +148,17 @@ void main_main () BL_PROFILE_VAR("LoadPytorch",LoadPytorch); + // Keep the model output in managed memory so the GPU copy-back does not + // depend on libtorch headers in the CUDA device compilation pass. + amrex::Gpu::ManagedVector aux_out; + Real* AMREX_RESTRICT auxOutPtr = nullptr; + +#if !defined(__CUDA_ARCH__) // Load pytorch module via torch script - torch::jit::script::Module module; + torch::jit::script::Module torch_module; try { // Deserialize the ScriptModule from a file using torch::jit::load(). - module = torch::jit::load(model_filename); + torch_module = torch::jit::load(model_filename); } catch (const c10::Error& e) { amrex::Abort("Error loading the model\n"); @@ -160,7 +171,7 @@ void main_main () #ifdef AMREX_USE_CUDA torch::Device device0(torch::kCUDA); - module.to(device0); + torch_module.to(device0); amrex::Print() << "Copying model to GPU." << std::endl; // set tensor options @@ -175,6 +186,7 @@ void main_main () // EVALUATE MODEL BL_PROFILE_VAR("Eval",Eval); +#endif // loop over boxes for ( MFIter mfi(phi_in); mfi.isValid(); ++mfi ) @@ -211,6 +223,10 @@ void main_main () auxPtr[index*Nc_in + n] = phi_input(i, j, k, n); }); + aux_out.resize(ncell*Nc_out); + auxOutPtr = aux_out.dataPtr(); + +#if !defined(__CUDA_ARCH__) // create torch tensor from array at::Tensor inputs_torch = torch::from_blob(auxPtr, {ncell, Nc_in}, tensoropt); @@ -218,17 +234,15 @@ void main_main () Real eval_t_start = ParallelDescriptor::second(); // evaluate torch model - at::Tensor outputs_torch = module.forward({inputs_torch}).toTensor(); - outputs_torch = outputs_torch.to(dtype0); + at::Tensor outputs_torch = torch_module.forward({inputs_torch}).toTensor(); + outputs_torch = outputs_torch.to(dtype0).to(torch::kCPU).contiguous(); // add eval time eval_time += ParallelDescriptor::second() - eval_t_start; - // get accessor to tensor (read-only) -#ifdef AMREX_USE_CUDA - auto outputs_torch_acc = outputs_torch.packed_accessor64(); -#else - auto outputs_torch_acc = outputs_torch.accessor(); + std::copy(outputs_torch.data_ptr(), + outputs_torch.data_ptr() + ncell*Nc_out, + auxOutPtr); #endif // copy tensor to output multifab @@ -241,11 +255,13 @@ void main_main () int kk = k - bx_lo[2]; index += kk*nbox[0]*nbox[1]; #endif - phi_output(i, j, k, n) = outputs_torch_acc[index][n]; + phi_output(i, j, k, n) = auxOutPtr[index*Nc_out + n]; }); } +#if !defined(__CUDA_ARCH__) BL_PROFILE_VAR_STOP(Eval); +#endif BL_PROFILE_VAR("Post",Post);