diff --git a/src/windows/common/DeviceHostProxy.cpp b/src/windows/common/DeviceHostProxy.cpp index ac691753c..9b0ac2cde 100644 --- a/src/windows/common/DeviceHostProxy.cpp +++ b/src/windows/common/DeviceHostProxy.cpp @@ -27,6 +27,15 @@ DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId { m_devicesShutdown = false; m_git = wil::CoCreateInstance(CLSID_StdGlobalInterfaceTable, CLSCTX_INPROC_SERVER); + + // Create a job object that will terminate device host processes when this proxy is destroyed + // (i.e., when the VM shuts down). + m_jobObject.reset(CreateJobObjectW(nullptr, nullptr)); + THROW_LAST_ERROR_IF(!m_jobObject); + + JOBOBJECT_EXTENDED_LIMIT_INFORMATION jobInfo{}; + jobInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + THROW_IF_WIN32_BOOL_FALSE(SetInformationJobObject(m_jobObject.get(), JobObjectExtendedLimitInformation, &jobInfo, sizeof(jobInfo))); } GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag) @@ -152,6 +161,15 @@ try const wil::com_ptr remoteHost = DeviceHost; const wil::com_ptr unknown = remoteHost.query(); THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle)); + + // Add the device host process to the job object so it is terminated when the VM shuts down. + wil::unique_handle process(OpenProcess(PROCESS_SET_QUOTA | PROCESS_TERMINATE, FALSE, ProcessId)); + LOG_LAST_ERROR_IF_MSG(!process, "Failed to open device host process %u for job assignment", ProcessId); + if (process) + { + LOG_IF_WIN32_BOOL_FALSE(AssignProcessToJobObject(m_jobObject.get(), process.get())); + } + return S_OK; } CATCH_RETURN() diff --git a/src/windows/common/DeviceHostProxy.h b/src/windows/common/DeviceHostProxy.h index 6eb65fbe5..d862c5521 100644 --- a/src/windows/common/DeviceHostProxy.h +++ b/src/windows/common/DeviceHostProxy.h @@ -114,6 +114,8 @@ class DeviceHostProxy : public wrl::RuntimeClass m_devices; bool m_devicesShutdown; + wil::unique_handle m_jobObject; + static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll"; static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll"; }; \ No newline at end of file diff --git a/src/windows/common/SubProcess.cpp b/src/windows/common/SubProcess.cpp index 55154a8fc..edea07dc2 100644 --- a/src/windows/common/SubProcess.cpp +++ b/src/windows/common/SubProcess.cpp @@ -75,6 +75,11 @@ void SubProcess::SetShowWindow(WORD ShowWindow) m_showWindow = ShowWindow; } +void SubProcess::SetJobObject(HANDLE JobObject) +{ + m_jobObject = JobObject; +} + wsl::windows::common::helpers::unique_proc_attribute_list SubProcess::BuildProcessAttributes() { DWORD attributes = 0; @@ -93,6 +98,11 @@ wsl::windows::common::helpers::unique_proc_attribute_list SubProcess::BuildProce attributes++; } + if (m_jobObject != nullptr) + { + attributes++; + } + if (attributes == 0) { return {}; @@ -123,6 +133,13 @@ wsl::windows::common::helpers::unique_proc_attribute_list SubProcess::BuildProce list.get(), 0, PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE, m_pseudoConsole, sizeof(m_pseudoConsole), nullptr, nullptr)); } + // Job object + if (m_jobObject != nullptr) + { + THROW_IF_WIN32_BOOL_FALSE(UpdateProcThreadAttribute( + list.get(), 0, PROC_THREAD_ATTRIBUTE_JOB_LIST, &m_jobObject, sizeof(m_jobObject), nullptr, nullptr)); + } + return list; } diff --git a/src/windows/common/SubProcess.h b/src/windows/common/SubProcess.h index 898930ee6..9a7013e37 100644 --- a/src/windows/common/SubProcess.h +++ b/src/windows/common/SubProcess.h @@ -40,6 +40,7 @@ class SubProcess void SetToken(HANDLE Token); void SetShowWindow(WORD Show); void SetFlags(DWORD Flag); + void SetJobObject(HANDLE JobObject); wil::unique_handle Start(); DWORD Run(DWORD Timeout = INFINITE); @@ -64,6 +65,7 @@ class SubProcess HANDLE m_stdOut = nullptr; HANDLE m_stdErr = nullptr; HPCON m_pseudoConsole = nullptr; + HANDLE m_jobObject = nullptr; std::optional m_desktopAppPolicy; std::optional m_showWindow; std::vector m_inheritHandles; diff --git a/src/windows/common/helpers.cpp b/src/windows/common/helpers.cpp index f3799f636..d8e7fbfd8 100644 --- a/src/windows/common/helpers.cpp +++ b/src/windows/common/helpers.cpp @@ -97,7 +97,8 @@ class ProcessLauncher } }; - [[nodiscard]] wil::unique_handle Launch(_In_opt_ HANDLE UserToken, _In_ bool HideWindow, _In_ bool CreateNoWindow = false) const + [[nodiscard]] wil::unique_handle Launch( + _In_opt_ HANDLE UserToken, _In_ bool HideWindow, _In_ bool CreateNoWindow = false, _In_opt_ HANDLE JobObject = nullptr) const { // If a user token was provided, create an environment block from the token. wsl::windows::common::helpers::unique_environment_block environmentBlock{nullptr}; @@ -125,6 +126,7 @@ class ProcessLauncher process.SetEnvironment(environmentBlock.get()); process.SetToken(UserToken); + process.SetJobObject(JobObject); // Launch the process. return process.Start(); @@ -137,7 +139,13 @@ class ProcessLauncher }; [[nodiscard]] wil::unique_handle LaunchWslHost( - _In_opt_ LPCGUID DistroId, _In_opt_ HANDLE InteropHandle, _In_opt_ HANDLE EventHandle, _In_opt_ HANDLE ParentHandle, _In_opt_ LPCGUID VmId, _In_opt_ HANDLE UserToken) + _In_opt_ LPCGUID DistroId, + _In_opt_ HANDLE InteropHandle, + _In_opt_ HANDLE EventHandle, + _In_opt_ HANDLE ParentHandle, + _In_opt_ LPCGUID VmId, + _In_opt_ HANDLE UserToken, + _In_opt_ HANDLE JobObject = nullptr) { // Construct the command line. // @@ -151,7 +159,7 @@ class ProcessLauncher launcher.AddHandleOption(wslhost::handle_option, InteropHandle); launcher.AddHandleOption(wslhost::event_option, EventHandle); launcher.AddHandleOption(wslhost::parent_option, ParentHandle); - return launcher.Launch(UserToken, true); + return launcher.Launch(UserToken, true, false, JobObject); } [[nodiscard]] wil::unique_handle LaunchWslRelay( @@ -162,7 +170,8 @@ class ProcessLauncher _In_opt_ std::optional Port, _In_opt_ HANDLE ExitEvent, _In_opt_ HANDLE UserToken, - _In_ LaunchWslRelayFlags Flags) + _In_ LaunchWslRelayFlags Flags, + _In_opt_ HANDLE JobObject = nullptr) { // Construct the command line. // @@ -191,7 +200,7 @@ class ProcessLauncher launcher.AddOption(wslrelay::connect_pipe_option); } - return launcher.Launch(UserToken, WI_IsFlagSet(Flags, LaunchWslRelayFlags::HideWindow)); + return launcher.Launch(UserToken, WI_IsFlagSet(Flags, LaunchWslRelayFlags::HideWindow), false, JobObject); } } // namespace @@ -548,7 +557,7 @@ bool wsl::windows::common::helpers::IsWslSupportInterfacePresent() } void wsl::windows::common::helpers::LaunchDebugConsole( - _In_ LPCWSTR PipeName, _In_ bool ConnectExistingPipe, _In_ HANDLE UserToken, _In_opt_ HANDLE LogFile, _In_ bool DisableTelemetry) + _In_ LPCWSTR PipeName, _In_ bool ConnectExistingPipe, _In_ HANDLE UserToken, _In_opt_ HANDLE LogFile, _In_ bool DisableTelemetry, _In_opt_ HANDLE JobObject) { LaunchWslRelayFlags flags{}; wil::unique_hfile pipe; @@ -576,16 +585,24 @@ void wsl::windows::common::helpers::LaunchDebugConsole( THROW_LAST_ERROR_IF(!pipe); WI_SetFlagIf(flags, LaunchWslRelayFlags::DisableTelemetry, DisableTelemetry); - wil::unique_handle info{LaunchWslRelay(wslrelay::RelayMode::DebugConsole, LogFile, nullptr, pipe.get(), {}, nullptr, UserToken, flags)}; + wil::unique_handle info{ + LaunchWslRelay(wslrelay::RelayMode::DebugConsole, LogFile, nullptr, pipe.get(), {}, nullptr, UserToken, flags, JobObject)}; } [[nodiscard]] wil::unique_handle wsl::windows::common::helpers::LaunchInteropServer( - _In_opt_ LPCGUID DistroId, _In_ HANDLE InteropHandle, _In_opt_ HANDLE EventHandle, _In_opt_ HANDLE ParentHandle, _In_opt_ LPCGUID VmId, _In_opt_ HANDLE UserToken) + _In_opt_ LPCGUID DistroId, + _In_ HANDLE InteropHandle, + _In_opt_ HANDLE EventHandle, + _In_opt_ HANDLE ParentHandle, + _In_opt_ LPCGUID VmId, + _In_opt_ HANDLE UserToken, + _In_opt_ HANDLE JobObject) { - return LaunchWslHost(DistroId, InteropHandle, EventHandle, ParentHandle, VmId, UserToken); + return LaunchWslHost(DistroId, InteropHandle, EventHandle, ParentHandle, VmId, UserToken, JobObject); } -void wsl::windows::common::helpers::LaunchKdRelay(_In_ LPCWSTR PipeName, _In_ HANDLE UserToken, _In_ int Port, _In_ HANDLE ExitEvent, _In_ bool DisableTelemetry) +void wsl::windows::common::helpers::LaunchKdRelay( + _In_ LPCWSTR PipeName, _In_ HANDLE UserToken, _In_ int Port, _In_ HANDLE ExitEvent, _In_ bool DisableTelemetry, _In_opt_ HANDLE JobObject) { // Create a new pipe server. The pipe should be: // Bi-directional: PIPE_ACCESS_DUPLEX @@ -599,15 +616,17 @@ void wsl::windows::common::helpers::LaunchKdRelay(_In_ LPCWSTR PipeName, _In_ HA LaunchWslRelayFlags flags = LaunchWslRelayFlags::ConnectPipe; WI_SetFlagIf(flags, LaunchWslRelayFlags::DisableTelemetry, DisableTelemetry); - wil::unique_handle info{LaunchWslRelay(wslrelay::RelayMode::KdRelay, nullptr, nullptr, pipe.get(), Port, ExitEvent, UserToken, flags)}; + wil::unique_handle info{ + LaunchWslRelay(wslrelay::RelayMode::KdRelay, nullptr, nullptr, pipe.get(), Port, ExitEvent, UserToken, flags, JobObject)}; } -void wsl::windows::common::helpers::LaunchPortRelay(_In_ SOCKET Socket, _In_ const GUID& VmId, _In_ HANDLE UserToken, _In_ bool DisableTelemetry) +void wsl::windows::common::helpers::LaunchPortRelay( + _In_ SOCKET Socket, _In_ const GUID& VmId, _In_ HANDLE UserToken, _In_ bool DisableTelemetry, _In_opt_ HANDLE JobObject) { LaunchWslRelayFlags flags{}; WI_SetFlagIf(flags, LaunchWslRelayFlags::DisableTelemetry, DisableTelemetry); wil::unique_handle info{LaunchWslRelay( - wslrelay::RelayMode::PortRelay, reinterpret_cast(Socket), &VmId, nullptr, {}, nullptr, UserToken, flags)}; + wslrelay::RelayMode::PortRelay, reinterpret_cast(Socket), &VmId, nullptr, {}, nullptr, UserToken, flags, JobObject)}; } void wsl::windows::common::helpers::LaunchWslSettingsOOBE(_In_ HANDLE UserToken) diff --git a/src/windows/common/helpers.hpp b/src/windows/common/helpers.hpp index 3b0cc7e37..1c5ef2322 100644 --- a/src/windows/common/helpers.hpp +++ b/src/windows/common/helpers.hpp @@ -165,7 +165,8 @@ bool IsWslOptionalComponentPresent(); bool IsWslSupportInterfacePresent(); -void LaunchDebugConsole(_In_ LPCWSTR PipeName, _In_ bool ConnectExistingPipe, _In_ HANDLE UserToken, _In_opt_ HANDLE LogFile, _In_ bool DisableTelemetry); +void LaunchDebugConsole( + _In_ LPCWSTR PipeName, _In_ bool ConnectExistingPipe, _In_ HANDLE UserToken, _In_opt_ HANDLE LogFile, _In_ bool DisableTelemetry, _In_opt_ HANDLE JobObject = nullptr); [[nodiscard]] wil::unique_handle LaunchInteropServer( _In_opt_ LPCGUID DistroId, @@ -173,11 +174,12 @@ void LaunchDebugConsole(_In_ LPCWSTR PipeName, _In_ bool ConnectExistingPipe, _I _In_opt_ HANDLE EventHandle, _In_opt_ HANDLE ParentHandle, _In_opt_ LPCGUID VmId, - _In_opt_ HANDLE UserToken = nullptr); + _In_opt_ HANDLE UserToken = nullptr, + _In_opt_ HANDLE JobObject = nullptr); -void LaunchKdRelay(_In_ LPCWSTR PipeName, _In_ HANDLE UserToken, _In_ int Port, _In_ HANDLE ExitEvent, _In_ bool DisableTelemetry); +void LaunchKdRelay(_In_ LPCWSTR PipeName, _In_ HANDLE UserToken, _In_ int Port, _In_ HANDLE ExitEvent, _In_ bool DisableTelemetry, _In_opt_ HANDLE JobObject = nullptr); -void LaunchPortRelay(_In_ SOCKET Socket, _In_ const GUID& VmId, _In_ HANDLE UserToken, _In_ bool DisableTelemetry); +void LaunchPortRelay(_In_ SOCKET Socket, _In_ const GUID& VmId, _In_ HANDLE UserToken, _In_ bool DisableTelemetry, _In_opt_ HANDLE JobObject = nullptr); void LaunchWslSettingsOOBE(_In_ HANDLE UserToken); diff --git a/src/windows/service/exe/WslCoreInstance.cpp b/src/windows/service/exe/WslCoreInstance.cpp index 4883d5a7d..7e449e88b 100644 --- a/src/windows/service/exe/WslCoreInstance.cpp +++ b/src/windows/service/exe/WslCoreInstance.cpp @@ -28,7 +28,8 @@ WslCoreInstance::WslCoreInstance( _In_ ULONG FeatureFlags, _In_ DWORD SocketTimeout, _In_ int IdleTimeout, - _Out_opt_ ULONG* ConnectPort) : + _Out_opt_ ULONG* ConnectPort, + _In_opt_ HANDLE JobObject) : LxssRunningInstance(IdleTimeout), m_featureFlags(FeatureFlags), m_instanceId(InstanceId), @@ -38,7 +39,8 @@ WslCoreInstance::WslCoreInstance( m_initializeDrvFs(DrvFsCallback), m_ntClientLifetimeId(ClientLifetimeId), m_redirectorConnectionTargets{m_configuration.Name}, - m_socketTimeout(SocketTimeout) + m_socketTimeout(SocketTimeout), + m_jobObject(JobObject) { // Establish a communication channel with the init daemon. m_initChannel = std::make_shared(InitSocket.release(), m_runtimeId, m_socketTimeout); @@ -125,7 +127,9 @@ WslCoreInstance::WslCoreInstance( DrvFsCallback, systemDistroFeatureFlags, m_socketTimeout, - IdleTimeout); + IdleTimeout, + nullptr, + JobObject); } CATCH_LOG() } @@ -419,7 +423,7 @@ void WslCoreInstance::Initialize() { const wil::unique_socket socket{wsl::windows::common::hvsocket::Connect(m_runtimeId, response.InteropPort)}; wil::unique_handle info{wsl::windows::common::helpers::LaunchInteropServer( - nullptr, reinterpret_cast(socket.get()), nullptr, nullptr, &m_runtimeId, m_userToken.get())}; + nullptr, reinterpret_cast(socket.get()), nullptr, nullptr, &m_runtimeId, m_userToken.get(), m_jobObject)}; } CATCH_LOG() } diff --git a/src/windows/service/exe/WslCoreInstance.h b/src/windows/service/exe/WslCoreInstance.h index ac2ce69fa..bf2adc2cc 100644 --- a/src/windows/service/exe/WslCoreInstance.h +++ b/src/windows/service/exe/WslCoreInstance.h @@ -67,7 +67,8 @@ class WslCoreInstance : public LxssRunningInstance _In_ ULONG FeatureFlags, _In_ DWORD SocketTimeout, _In_ int IdleTimeout, - _Out_opt_ ULONG* ConnectPort = nullptr); + _Out_opt_ ULONG* ConnectPort = nullptr, + _In_opt_ HANDLE JobObject = nullptr); virtual ~WslCoreInstance(); @@ -136,6 +137,7 @@ class WslCoreInstance : public LxssRunningInstance std::shared_ptr m_systemDistro; WSLDistributionInformation m_distributionInfo{}; DWORD m_socketTimeout{}; + HANDLE m_jobObject{}; std::thread m_oobeThread; wil::unique_event m_destroyingEvent{wil::EventOptions::ManualReset}; wil::unique_event m_oobeCompleteEvent; diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index 3c08097f5..84dc87770 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -78,6 +78,14 @@ RequiredExtraMmioSpaceForPmemFileInMb(_In_ PCWSTR FilePath) WslCoreVm::WslCoreVm(_In_ wsl::core::Config&& VmConfig) : m_vmConfig(std::move(VmConfig)), m_traceClient(m_vmConfig.EnableTelemetry) { + // Create a job object that will terminate child processes (wslhost.exe, wslrelay.exe) + // when the VM is destroyed. + m_processJobObject.reset(CreateJobObjectW(nullptr, nullptr)); + THROW_LAST_ERROR_IF(!m_processJobObject); + + JOBOBJECT_EXTENDED_LIMIT_INFORMATION jobInfo{}; + jobInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + THROW_IF_WIN32_BOOL_FALSE(SetInformationJobObject(m_processJobObject.get(), JobObjectExtendedLimitInformation, &jobInfo, sizeof(jobInfo))); } std::unique_ptr WslCoreVm::Create(_In_ const wil::shared_handle& UserToken, _In_ wsl::core::Config&& VmConfig, _In_ const GUID& VmId) @@ -309,7 +317,12 @@ void WslCoreVm::Initialize(const GUID& VmId, const wil::shared_handle& UserToken } wsl::windows::common::helpers::LaunchDebugConsole( - m_comPipe0.c_str(), !!m_dmesgCollector, m_restrictedToken.get(), logFile ? logFile.get() : nullptr, !m_vmConfig.EnableTelemetry); + m_comPipe0.c_str(), + !!m_dmesgCollector, + m_restrictedToken.get(), + logFile ? logFile.get() : nullptr, + !m_vmConfig.EnableTelemetry, + m_processJobObject.get()); } CATCH_LOG() } @@ -1236,7 +1249,8 @@ std::shared_ptr WslCoreVm::CreateInstanceInternal( featureFlags, m_vmConfig.DistributionStartTimeout, m_vmConfig.InstanceIdleTimeout, - ConnectPort); + ConnectPort, + m_processJobObject.get()); WI_ASSERT(!initSocket && !systemDistroSocket); @@ -1638,7 +1652,12 @@ std::wstring WslCoreVm::GenerateConfigJson() m_comPipe1 = wsl::windows::common::helpers::GetUniquePipeName(); wsl::windows::common::helpers::LaunchKdRelay( - m_comPipe1.c_str(), m_restrictedToken.get(), m_vmConfig.KernelDebugPort, m_terminatingEvent.get(), !m_vmConfig.EnableTelemetry); + m_comPipe1.c_str(), + m_restrictedToken.get(), + m_vmConfig.KernelDebugPort, + m_terminatingEvent.get(), + !m_vmConfig.EnableTelemetry, + m_processJobObject.get()); } else { @@ -1857,7 +1876,8 @@ void WslCoreVm::InitializeGuest() // N.B. The relay process is launched at medium integrity level, and its lifetime is tied to the lifetime of the utility VM. const auto result = wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&]() { const auto socket = AcceptConnection(m_vmConfig.KernelBootTimeout); - wsl::windows::common::helpers::LaunchPortRelay(socket.get(), m_runtimeId, m_restrictedToken.get(), !m_vmConfig.EnableTelemetry); + wsl::windows::common::helpers::LaunchPortRelay( + socket.get(), m_runtimeId, m_restrictedToken.get(), !m_vmConfig.EnableTelemetry, m_processJobObject.get()); }); if (FAILED(result)) diff --git a/src/windows/service/exe/WslCoreVm.h b/src/windows/service/exe/WslCoreVm.h index b3776d375..98440a2bf 100644 --- a/src/windows/service/exe/WslCoreVm.h +++ b/src/windows/service/exe/WslCoreVm.h @@ -319,6 +319,10 @@ class WslCoreVm _Guarded_by_(m_persistentMemoryLock) ULONG m_nextPersistentMemoryId = 0; std::unique_ptr m_networkingEngine; + + // Job object that terminates child processes (wslhost.exe, wslrelay.exe) + // when the VM shuts down. + wil::unique_handle m_processJobObject; }; DEFINE_ENUM_FLAG_OPERATORS(WslCoreVm::DiskStateFlags);