From 2fdb4fb880f327aa3e38e98111cfc45fc83fa0eb Mon Sep 17 00:00:00 2001 From: juanan Date: Thu, 15 Jun 2023 22:03:26 +0200 Subject: [PATCH 1/8] Adding possibility to enable MT while importing a TRestDataSet --- source/framework/core/inc/TRestDataSet.h | 2 +- source/framework/core/src/TRestDataSet.cxx | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/framework/core/inc/TRestDataSet.h b/source/framework/core/inc/TRestDataSet.h index 94f0a490d..28d185f7a 100644 --- a/source/framework/core/inc/TRestDataSet.h +++ b/source/framework/core/inc/TRestDataSet.h @@ -169,8 +169,8 @@ class TRestDataSet : public TRestMetadata { inline void SetQuantity(const std::map& quantity) { fQuantity = quantity; } TRestDataSet& operator=(TRestDataSet& dS); - void Import(const std::string& fileName); void Import(std::vector fileNames); + void Import(const std::string& fileName, bool enableMT = true); void Export(const std::string& filename); ROOT::RDF::RNode MakeCut(const TRestCut* cut); diff --git a/source/framework/core/src/TRestDataSet.cxx b/source/framework/core/src/TRestDataSet.cxx index 430aba8aa..618f40667 100644 --- a/source/framework/core/src/TRestDataSet.cxx +++ b/source/framework/core/src/TRestDataSet.cxx @@ -890,7 +890,7 @@ TRestDataSet& TRestDataSet::operator=(TRestDataSet& dS) { /// it import metadata info from the previous dataSet /// while it opens the analysis tree /// -void TRestDataSet::Import(const std::string& fileName) { +void TRestDataSet::Import(const std::string& fileName, bool enableMT) { if (TRestTools::GetFileNameExtension(fileName) != "root") { RESTError << "Datasets can only be imported from root files" << RESTendl; return; @@ -918,7 +918,7 @@ void TRestDataSet::Import(const std::string& fileName) { return; } - ROOT::EnableImplicitMT(); + if(enableMT)ROOT::EnableImplicitMT(); fDataSet = ROOT::RDataFrame("AnalysisTree", fileName); From 7c132123b333aa116cebe8403e23131385e3b30d Mon Sep 17 00:00:00 2001 From: juanan Date: Thu, 15 Jun 2023 22:07:17 +0200 Subject: [PATCH 2/8] Adding new function to open and update output file --- source/framework/core/inc/TRestRun.h | 1 + source/framework/core/src/TRestRun.cxx | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/source/framework/core/inc/TRestRun.h b/source/framework/core/inc/TRestRun.h index ec4979bf4..9f69deb00 100644 --- a/source/framework/core/inc/TRestRun.h +++ b/source/framework/core/inc/TRestRun.h @@ -92,6 +92,7 @@ class TRestRun : public TRestMetadata { TFile* MergeToOutputFile(std::vector filefullnames, std::string outputfilename = ""); TFile* FormOutputFile(); TFile* UpdateOutputFile(); + TFile* OpenAndUpdateOutputFile(); void PassOutputFile() { fOutputFile = fInputFile; diff --git a/source/framework/core/src/TRestRun.cxx b/source/framework/core/src/TRestRun.cxx index 0069fb8dd..d483d879e 100644 --- a/source/framework/core/src/TRestRun.cxx +++ b/source/framework/core/src/TRestRun.cxx @@ -1079,6 +1079,18 @@ TFile* TRestRun::UpdateOutputFile() { return nullptr; } +/////////////////////////////////////////////// +/// \brief Open and update output file in case is closed +/// +TFile* TRestRun::OpenAndUpdateOutputFile() { + if (fOutputFile == nullptr) { + fOutputFile = TFile::Open(fOutputFileName,"UPDATE"); + } + + return UpdateOutputFile(); + +} + /////////////////////////////////////////////// /// \brief Write this object into TFile and add a new entry in database /// From 56fb7d08fffe7d2afd1e1fe1b6a0c02f8410cbcc Mon Sep 17 00:00:00 2001 From: juanan Date: Thu, 15 Jun 2023 22:08:59 +0200 Subject: [PATCH 3/8] Avoid use of new and delete in TFileMerger --- source/framework/core/src/TRestRun.cxx | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/source/framework/core/src/TRestRun.cxx b/source/framework/core/src/TRestRun.cxx index d483d879e..9b7a88ce0 100644 --- a/source/framework/core/src/TRestRun.cxx +++ b/source/framework/core/src/TRestRun.cxx @@ -984,24 +984,25 @@ TString TRestRun::FormFormat(const TString& FilenameFormat) { TFile* TRestRun::MergeToOutputFile(vector filenames, string outputfilename) { RESTDebug << "TRestRun::FormOutputFile. target : " << outputfilename << RESTendl; string filename; - TFileMerger* m = new TFileMerger(false); + + TFileMerger m (false); if (outputfilename == "") { filename = fOutputFileName; RESTInfo << "Creating file : " << filename << RESTendl; - m->OutputFile(filename.c_str(), "RECREATE"); + m.OutputFile(filename.c_str(), "RECREATE"); } else { filename = outputfilename; - RESTInfo << "Updating file : " << filename << RESTendl; - m->OutputFile(filename.c_str(), "UPDATE"); + RESTInfo << "Creating file : " << filename << RESTendl; + m.OutputFile(filename.c_str(), "UPDATE"); } RESTDebug << "TRestRun::FormOutputFile. Starting to add files" << RESTendl; for (unsigned int i = 0; i < filenames.size(); i++) { - m->AddFile(filenames[i].c_str(), false); + m.AddFile(filenames[i].c_str(), false); } - if (m->Merge()) { + if (m.Merge()) { for (unsigned int i = 0; i < filenames.size(); i++) { remove(filenames[i].c_str()); } @@ -1011,8 +1012,6 @@ TFile* TRestRun::MergeToOutputFile(vector filenames, string outputfilena exit(1); } - delete m; - // we rename the created output file fOutputFileName = FormFormat(filename); rename(filename.c_str(), fOutputFileName); From 7cbb80cbf75540dde13b844f7bf41281332e0925 Mon Sep 17 00:00:00 2001 From: juanan Date: Thu, 15 Jun 2023 22:16:24 +0200 Subject: [PATCH 4/8] Adding new analysis TRestDataSetTMVA to compute root TMVA over a dataSet --- .../framework/analysis/inc/TRestDataSetTMVA.h | 90 +++++ .../analysis/src/TRestDataSetTMVA.cxx | 368 ++++++++++++++++++ 2 files changed, 458 insertions(+) create mode 100644 source/framework/analysis/inc/TRestDataSetTMVA.h create mode 100644 source/framework/analysis/src/TRestDataSetTMVA.cxx diff --git a/source/framework/analysis/inc/TRestDataSetTMVA.h b/source/framework/analysis/inc/TRestDataSetTMVA.h new file mode 100644 index 000000000..994c0d9db --- /dev/null +++ b/source/framework/analysis/inc/TRestDataSetTMVA.h @@ -0,0 +1,90 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +#ifndef REST_TRestDataSetTMVA +#define REST_TRestDataSetTMVA + +#include "TH1F.h" +#include "TRestCut.h" +#include "TRestMetadata.h" +#include "TMVA/Types.h" + +/// This class is meant to evaluate several TMVA methods in datasets +class TRestDataSetTMVA : public TRestMetadata { + private: + /// Name of the output file + std::string fOutputFileName = ""; //< + + /// Name of the signal dataSet + std::string fDataSetSignal = ""; //< + + /// Name of the background dataset + std::string fDataSetBackground = ""; //< + + /// Name of the output path for the xml files + std::string fOutputPath = ""; //< + + /// Vector containing different obserbable names + std::vector fObsName; //< + + /// Add method to compute TMVA, https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf for more details + std::vector > fMethod; //< + + /// Cuts over background dataset for PDF selection + TRestCut* fBackgroundCut = nullptr; //< + + /// Cuts over signal dataset for PDF selection + TRestCut* fSignalCut = nullptr; //< + + /// If true display ROC curve after evaluating all methods + bool fDrawROCCurve = true; //< + + /// Map with supported TMVA methods, please add more if something is missing + const std::map fMethodMap = { //< + {"Likelihood", TMVA::Types::kLikelihood }, // Likelihood ("naive Bayes estimator") + {"LikelihoodKDE", TMVA::Types::kLikelihood }, // Use a kernel density estimator to approximate the PDFs + {"Fisher", TMVA::Types::kFisher }, // Fisher discriminant (same as LD) + {"BDT", TMVA::Types::kBDT }, //Boosted Decision Trees + {"MLP", TMVA::Types::kMLP } //Multi-Layer Perceptron (Neural Network) + }; + + void Initialize() override; + void InitFromConfigFile() override; + + public: + + void PrintMetadata() override; + + void ComputeTMVA(); + + inline void SetDataSetSignal(const std::string& dSName) { fDataSetSignal = dSName; } + inline void SetDataSetBackground(const std::string& dSName) { fDataSetBackground = dSName; } + inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; } + inline void SetOutputPath(const std::string& outPath) { fOutputPath = outPath; } + + TRestDataSetTMVA(); + TRestDataSetTMVA(const char* configFilename, std::string name = ""); + ~TRestDataSetTMVA(); + + ClassDefOverride(TRestDataSetTMVA, 1); +}; +#endif diff --git a/source/framework/analysis/src/TRestDataSetTMVA.cxx b/source/framework/analysis/src/TRestDataSetTMVA.cxx new file mode 100644 index 000000000..b6ea618fb --- /dev/null +++ b/source/framework/analysis/src/TRestDataSetTMVA.cxx @@ -0,0 +1,368 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +///////////////////////////////////////////////////////////////////////// +/// TRestDataSetTMVA is meant to evaluate different TMVA methods in datasets. +/// For more information about TMVA, check https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf +/// So far, only Likelihood, LikelihoodKDE, Fisher, BDT and MLP methods are +/// supported. TMVA requires a signal and a background dataset from which the +/// different TMVA methods are computed. The different methods are evaluated +/// in a set of observables that are provided in the RML file. Different cuts +/// can be performed in either the signal or the background datasets prior to +/// the TMVA evaluation. The output of this class is a root file which contains +/// a signal and a background tree with the cuts applied and the different observables +/// that are generated with the TMVA analysis. In addition, a folder is created +/// with different xml files that contain the output of the TMVA evaluation that +/// can be used to compute the TMVA classification via TRestDataSetTMVAClassification. +/// +/// A summary of the basic parameters is described below: +/// * **outputFileName**: Name of the output file +/// * **dataSetSignal**: Name of the dataset file containing the signal +/// * **dataSetBackground**: Name of the dataset file containing the background +/// * **outputPath**: Name of the output path with the evaluation results +/// * **drawROCCurve**: If true display the ROC curve for the evaluation of all methods +/// +/// The different observables for the TMVA analysis can be added with the following key: +/// \code +/// +/// \endcode +/// +/// * **name**: Name of the observable be computed +/// +/// The different signal and background cuts can be added awith the following key: +/// \code +/// +/// +/// +/// \endcode +/// +/// Where the cut name (e.g. ParamCut or Energy cut from above) have to be defined inside +/// the RML file, e.g.: +/// \code +/// +/// +/// +/// +/// +/// +/// +/// +/// \endcode +/// +/// Please, check TRestCut class for more info. +/// +/// The different TMVA methods can be added wit the following key: +/// \code +/// +/// \endcode +/// The different parameters for adding TMVA methods are described below: +/// * **name**: Name of the TMVA method, only Likelihood, LikelihoodKDE, Fisher, BDT and MLP +/// are supported so far. +/// * **parameters**: String parameters to be used in the TMVA method, for more information +/// check https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf +/// +/// ### Examples +/// Example of RML config file: +/// \code +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// \endcode +/// +/// Example to perform TMVA evaluation using restRoot: +/// \code +/// [0] TRestDataSetTMVA tmva("tmva.rml"); +/// [1] tmva.SetDataSetSignal("DataSetSignal.root"); +/// [2] tmva.SetDataSetBackground("DataSetBackground.root"); +/// [3] tmva.SetOutputFileName("MyDataSetEvaluation.root"); +/// [4] tmva.SetOutputPath("MyDataSetFiles"); +/// [5] tmva.ComputeTMVA(); +/// \endcode +/// +/// In addition it is possible to display TMVA results after evaluating all methods, +/// using root or restRoot; +/// \code +/// [0] TMVA::TMVAGui("MyDataSetEvaluation.root") +/// \endcode +/// +///---------------------------------------------------------------------- +/// +/// REST-for-Physics - Software for Rare Event Searches Toolkit +/// +/// History of developments: +/// +/// 2023-05: First implementation of TRestDataSetTMVA +/// JuanAn Garcia +/// +/// \class TRestDataSetTMVA +/// \author: JuanAn Garcia e-mail: juanangp@unizar.es +/// +///
+/// + +#include "TRestDataSetTMVA.h" + +#include "TRestDataSet.h" + +#include "TMVA/CrossValidation.h" +#include "TMVA/DataLoader.h" +#include "ROOT/RDFHelpers.hxx" +#include "TMVA/Factory.h" +#include "TMVA/Tools.h" +#include "TMVA/TMVAGui.h" + +ClassImp(TRestDataSetTMVA); + +/////////////////////////////////////////////// +/// \brief Default constructor +/// +TRestDataSetTMVA::TRestDataSetTMVA() { Initialize(); } + +///////////////////////////////////////////// +/// \brief Constructor loading data from a config file +/// +/// If no configuration path is defined using TRestMetadata::SetConfigFilePath +/// the path to the config file must be specified using full path, absolute or +/// relative. +/// +/// The default behaviour is that the config file must be specified with +/// full path, absolute or relative. +/// +/// \param configFilename A const char* that defines the RML filename. +/// \param name The name of the metadata section. It will be used to find the +/// corresponding TRestDataSetTMVA section inside the RML. +/// +TRestDataSetTMVA::TRestDataSetTMVA(const char* configFilename, std::string name) + : TRestMetadata(configFilename) { + LoadConfigFromFile(fConfigFileName, name); + Initialize(); + + if (GetVerboseLevel() >= TRestStringOutput::REST_Verbose_Level::REST_Info) PrintMetadata(); +} + +/////////////////////////////////////////////// +/// \brief Default destructor +/// +TRestDataSetTMVA::~TRestDataSetTMVA() {} + +/////////////////////////////////////////////// +/// \brief Function to initialize input/output event members and define +/// the section name +/// +void TRestDataSetTMVA::Initialize() { SetSectionName(this->ClassName()); } + +/////////////////////////////////////////////// +/// \brief Function to initialize some variables from +/// configfile +/// +void TRestDataSetTMVA::InitFromConfigFile() { + Initialize(); + TRestMetadata::InitFromConfigFile(); + + TiXmlElement* obsDefinition = GetElement("observable"); + while (obsDefinition != nullptr) { + std::string obsName = GetFieldValue("name", obsDefinition); + if (obsName.empty() || obsName == "Not defined") { + RESTError << "< observable variable key does not contain a name!" << RESTendl; + exit(1); + } else { + fObsName.push_back(obsName); + } + + obsDefinition = GetNextElement(obsDefinition); + } + + TiXmlElement* cutele = GetElement("addBackgroundCut"); + while (cutele != nullptr) { + std::string cutName = GetParameter("name", cutele, ""); + if (!cutName.empty()) { + if (fBackgroundCut == nullptr) { + fBackgroundCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName); + } else { + fBackgroundCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName)); + } + } + cutele = GetNextElement(cutele); + } + + cutele = GetElement("addSignalCut"); + while (cutele != nullptr) { + std::string cutName = GetParameter("name", cutele, ""); + if (!cutName.empty()) { + if (fSignalCut == nullptr) { + fSignalCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName); + } else { + fSignalCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName)); + } + } + cutele = GetNextElement(cutele); + } + + TiXmlElement* method = GetElement("addMethod"); + while (method != nullptr) { + std::string name = GetParameter("name", method, ""); + std::string params = GetParameter("parameters", method, ""); + if (name.empty() || params.empty()) { + RESTWarning << "Empty method" << RESTendl; + } else { + fMethod.push_back(std::make_pair(name, params)); + } + method = GetNextElement(method); + } + + if (fObsName.empty() ) { + RESTError << "No observables provided, exiting..." << RESTendl; + exit(1); + } + + if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", ""); + +} + +///////////////////////////////////////////// +/// \brief This function computes the TMVA using +/// the different methods provided via config file +/// and the signal and background dataSets. The results +/// are stored in an output root file and a folder. Note +/// that it doesn't provide any usable dataset, just standard +/// root files. +/// +void TRestDataSetTMVA::ComputeTMVA() { + + if(fOutputFileName.empty() || fOutputPath.empty() || fDataSetSignal.empty() || fDataSetBackground.empty() ){ + RESTError <<"Empty output file name, path, signal or background files "<Get("Signal"); + auto bckTree = outputFile->Get("Background"); + + TMVA::Factory factory("TMVA_Classification", outputFile, + "!V:ROC:!Silent:Color:AnalysisType=Classification" ); + + TMVA::DataLoader loader (fOutputPath); + + // Add observables for the evaluation + for(const auto &obs : fObsName)loader.AddVariable(obs); + + loader.AddSignalTree ( signalTree, 1.0); + loader.AddBackgroundTree( bckTree, 1.0); + loader.PrepareTrainingAndTestTree( "","", + ":SplitMode=Random" + ":NormMode=NumEvents" + ":!V"); + + // Add different TMVA methods + for(const auto &[name, params] : fMethod){ + auto it = fMethodMap.find(name); + if(it == fMethodMap.end() ){ + RESTWarning << "Method " << name << " not supported "<second << " " << params << std::endl; + factory.BookMethod(&loader, it->second, name.c_str(), params.c_str()); + } + + // Train, test and evaluate all methods + factory.TrainAllMethods(); + factory.TestAllMethods(); + factory.EvaluateAllMethods(); + + // Draw ROC curve + if (fDrawROCCurve && gApplication != nullptr && gApplication->IsRunning()){ + auto c1 = factory.GetROCCurve(&loader); + c1->Draw(); + } + + outputFile->Close(); + +} + +///////////////////////////////////////////// +/// \brief Prints on screen the information about the metadata members of TRestDataSetTMVA +/// +void TRestDataSetTMVA::PrintMetadata() { + TRestMetadata::PrintMetadata(); + + RESTMetadata << " Observables to compute: " << RESTendl; + for (const auto & obs : fObsName) { + RESTMetadata << obs << RESTendl; + } + RESTMetadata << " TMVA Methods " << RESTendl; + for(const auto &[name, params] : fMethod){ + RESTMetadata << name << " "<< params << RESTendl; + } + RESTMetadata << "----" << RESTendl; +} From dc365b76840a38d9ab0c034225ca2890f78521db Mon Sep 17 00:00:00 2001 From: juanan Date: Thu, 15 Jun 2023 22:17:13 +0200 Subject: [PATCH 5/8] Adding new class TRestDataSetTMVAClassification to compute TMVA score over a TRestDataSet --- .../inc/TRestDataSetTMVAClassification.h | 70 +++++ .../src/TRestDataSetTMVAClassification.cxx | 295 ++++++++++++++++++ 2 files changed, 365 insertions(+) create mode 100644 source/framework/analysis/inc/TRestDataSetTMVAClassification.h create mode 100644 source/framework/analysis/src/TRestDataSetTMVAClassification.cxx diff --git a/source/framework/analysis/inc/TRestDataSetTMVAClassification.h b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h new file mode 100644 index 000000000..fe558ee4d --- /dev/null +++ b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h @@ -0,0 +1,70 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +#ifndef REST_TRestDataSetTMVAClassification +#define REST_TRestDataSetTMVAClassification + +#include "TH1F.h" +#include "TRestCut.h" +#include "TRestMetadata.h" + +/// This class is meant to classify a given dataset with a particular TMVA method +class TRestDataSetTMVAClassification : public TRestMetadata { + private: + /// Name of the output file + std::string fOutputFileName = ""; //< + + /// Name of the dataSet to classify + std::string fDataSetName = ""; //< + + /// Name of the TMVA method + std::string fTmvaMethod = ""; //< + + /// Name of the TMVA weights file + std::string fTmvaFile = ""; //< + + /// Vector containing different obserbable names + std::vector fObsName; //< + + /// Cuts over the dataset for PDF selection + TRestCut* fCut = nullptr; //< + + void Initialize() override; + void InitFromConfigFile() override; + + public: + void PrintMetadata() override; + + void ClassifyTMVA(); + + inline void SetDataSet(const std::string& dSName) { fDataSetName = dSName; } + inline void SetTMVAMethod(const std::string& method) { fTmvaMethod = method; } + inline void SetTMVAFile(const std::string& file) { fTmvaFile = file; } + inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; } + + TRestDataSetTMVAClassification(); + TRestDataSetTMVAClassification(const char* configFilename, std::string name = ""); + ~TRestDataSetTMVAClassification(); + + ClassDefOverride(TRestDataSetTMVAClassification, 1); +}; +#endif diff --git a/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx new file mode 100644 index 000000000..9c64c3100 --- /dev/null +++ b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx @@ -0,0 +1,295 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +///////////////////////////////////////////////////////////////////////// +/// TRestDataSetTMVAClassification performs the classification of a given +/// dataSet using as input the results of the TMVA evaluation methods +/// generated using TRestDataSetTMVA. Note that the observables used on +/// TRestDataSetTMVA and TRestDataSetTMVA needs to match. This class generates +/// as output a dataset with a new observable which is defined using the name of +/// the TMVA method that has been used to classify the dataset. Only one TMVA +/// method is classified in ClassifyTMVA. An output dataset is generated by +/// definining a new observable with the TMVA method e.g. BDT_score +/// +/// A summary of the basic parameters is described below: +/// * **dataSetName**: Name of the dataSet to be classified +/// * **tmvaFile**: Name of the xml input file with the tmva weigths +/// * **tmvaMethod**: Name of the TMVA method used to classify +/// * **outputFileName**: Name of the output dataset +/// +/// +/// The different observables for the TMVA classification can be added with the following key: +/// \code +/// +/// \endcode +/// * **name**: Name of the observable be computed +/// +/// Note that the observable names has to match the ones using for the evaluation of +/// a particular TMVA method +/// +/// Different cuts over the dataset can be added with the following key: +/// \code +/// +/// \endcode +/// +/// ### Examples +/// Example of RML config file: +/// \code +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// \endcode +/// +/// Example of TRestDataSetTMVAClassification using restRoot: +/// \code +/// TRestDataSetTMVAClassification tmva("tmva.rml"); +/// tmva.SetDataSet("MyDataSet.root"); +/// tmva.SetOutputFileName("MyClassifiedDataSet.root"); +/// tmva.ClassifyTMVA(); +/// \endcode +/// +///---------------------------------------------------------------------- +/// +/// REST-for-Physics - Software for Rare Event Searches Toolkit +/// +/// History of developments: +/// +/// 2023-03: First implementation of TRestDataSetTMVAClassification +/// JuanAn Garcia +/// +/// \class TRestDataSetTMVAClassification +/// \author: JuanAn Garcia e-mail: juanangp@unizar.es +/// +///
+/// + +#include "TRestDataSetTMVAClassification.h" + +#include "TRestDataSet.h" + +#include "TMVA/CrossValidation.h" +#include "TMVA/DataLoader.h" +#include "TMVA/RReader.hxx" +#include "ROOT/RDFHelpers.hxx" +#include "TMVA/RTensorUtils.hxx" +#include "TMVA/RInferenceUtils.hxx" +#include "TMVA/Factory.h" +#include "TMVA/Tools.h" + +ClassImp(TRestDataSetTMVAClassification); + +/////////////////////////////////////////////// +/// \brief Default constructor +/// +TRestDataSetTMVAClassification::TRestDataSetTMVAClassification() { Initialize(); } + +///////////////////////////////////////////// +/// \brief Constructor loading data from a config file +/// +/// If no configuration path is defined using TRestMetadata::SetConfigFilePath +/// the path to the config file must be specified using full path, absolute or +/// relative. +/// +/// The default behaviour is that the config file must be specified with +/// full path, absolute or relative. +/// +/// \param configFilename A const char* that defines the RML filename. +/// \param name The name of the metadata section. It will be used to find the +/// corresponding TRestDataSetTMVAClassification section inside the RML. +/// +TRestDataSetTMVAClassification::TRestDataSetTMVAClassification(const char* configFilename, std::string name) + : TRestMetadata(configFilename) { + LoadConfigFromFile(fConfigFileName, name); + Initialize(); + + if (GetVerboseLevel() >= TRestStringOutput::REST_Verbose_Level::REST_Info) PrintMetadata(); +} + +/////////////////////////////////////////////// +/// \brief Default destructor +/// +TRestDataSetTMVAClassification::~TRestDataSetTMVAClassification() {} + +/////////////////////////////////////////////// +/// \brief Function to initialize input/output event members and define +/// the section name +/// +void TRestDataSetTMVAClassification::Initialize() { SetSectionName(this->ClassName()); } + +/////////////////////////////////////////////// +/// \brief Function to initialize some variables from +/// configfile +/// +void TRestDataSetTMVAClassification::InitFromConfigFile() { + Initialize(); + TRestMetadata::InitFromConfigFile(); + + TiXmlElement* obsDefinition = GetElement("observable"); + while (obsDefinition != nullptr) { + std::string obsName = GetFieldValue("name", obsDefinition); + if (obsName.empty() || obsName == "Not defined") { + RESTError << "< observable variable key does not contain a name!" << RESTendl; + exit(1); + } else { + fObsName.push_back(obsName); + } + + obsDefinition = GetNextElement(obsDefinition); + } + + TiXmlElement* cutele = GetElement("addCut"); + while (cutele != nullptr) { + std::string cutName = GetParameter("name", cutele, ""); + if (!cutName.empty()) { + if (fCut == nullptr) { + fCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName); + } else { + fCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName)); + } + } + cutele = GetNextElement(cutele); + } + + if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", ""); +} + +///////////////////////////////////////////// +/// \brief This function computes the TMVA classification +/// for a given dataSet. It requires a xml file with weigths +/// from the output of TRestDataSetTMVA to perform the +/// clasification for a given set of observables. This function +/// defines a new observable with the score of the TMVA method +/// provided in the input file that can be used for further +/// signal and background discrimination. +/// +void TRestDataSetTMVAClassification::ClassifyTMVA() { + PrintMetadata(); + + if (fObsName.empty() ) { + RESTError << "No observables provided, exiting..." << RESTendl; + exit(1); + } + + TMVA::Reader reader ( "!Color:!Silent" ); + std::vector var (fObsName.size()); + + // Add variables to the reader + for(unsigned int i=0; i &val) { return reader.EvaluateMVA(val, tmvaMethod.c_str()); }; + + TRestDataSet dataSet; + dataSet.Import(fDataSetName); + + auto df = dataSet.MakeCut(fCut); + + std::string obsName = fTmvaMethod +"_score"; + + // Ugly but cannot pass vector size to ROOT::RDF::PassAsVec + switch (fObsName.size()){ + case 1: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<1, double>(eval), fObsName); + break; + case 2: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<2, double>(eval), fObsName); + break; + case 3: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<3, double>(eval), fObsName); + break; + case 4: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<4, double>(eval), fObsName); + break; + case 5: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<5, double>(eval), fObsName); + break; + case 6: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<6, double>(eval), fObsName); + break; + case 7: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<7, double>(eval), fObsName); + break; + case 8: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<8, double>(eval), fObsName); + break; + case 9: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<9, double>(eval), fObsName); + break; + case 10: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<10, double>(eval), fObsName); + break; + case 11: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<11, double>(eval), fObsName); + break; + default: + RESTError<<"Number of observables "<Write(); + f->Close(); + } + } + +} + +///////////////////////////////////////////// +/// \brief Prints on screen the information about the metadata members of TRestDataSetTMVAClassification +/// +void TRestDataSetTMVAClassification::PrintMetadata() { + TRestMetadata::PrintMetadata(); + + RESTMetadata << " Observables to compute: " << RESTendl; + for (size_t i = 0; i < fObsName.size(); i++) { + RESTMetadata << fObsName[i] << RESTendl; + } + RESTMetadata << "----" << RESTendl; +} From dd0d31e1302b017e723da583981327016a668a01 Mon Sep 17 00:00:00 2001 From: juanan Date: Thu, 15 Jun 2023 22:18:35 +0200 Subject: [PATCH 6/8] Adding tmva rml example --- examples/tmva.rml | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 examples/tmva.rml diff --git a/examples/tmva.rml b/examples/tmva.rml new file mode 100644 index 000000000..b90ed6c2c --- /dev/null +++ b/examples/tmva.rml @@ -0,0 +1,66 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 736aeb38460ac4bd221a68501fff3baa26f53d71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jun 2023 20:29:57 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/tmva.rml | 4 +- .../framework/analysis/inc/TRestDataSetTMVA.h | 35 ++--- .../inc/TRestDataSetTMVAClassification.h | 12 +- .../analysis/src/TRestDataSetTMVA.cxx | 131 ++++++++-------- .../src/TRestDataSetTMVAClassification.cxx | 140 +++++++++--------- source/framework/core/src/TRestDataSet.cxx | 2 +- source/framework/core/src/TRestRun.cxx | 7 +- 7 files changed, 167 insertions(+), 164 deletions(-) diff --git a/examples/tmva.rml b/examples/tmva.rml index b90ed6c2c..fe965ce6c 100644 --- a/examples/tmva.rml +++ b/examples/tmva.rml @@ -27,7 +27,7 @@ - + @@ -58,7 +58,7 @@ - + diff --git a/source/framework/analysis/inc/TRestDataSetTMVA.h b/source/framework/analysis/inc/TRestDataSetTMVA.h index 994c0d9db..e650a16e1 100644 --- a/source/framework/analysis/inc/TRestDataSetTMVA.h +++ b/source/framework/analysis/inc/TRestDataSetTMVA.h @@ -24,54 +24,55 @@ #define REST_TRestDataSetTMVA #include "TH1F.h" +#include "TMVA/Types.h" #include "TRestCut.h" #include "TRestMetadata.h" -#include "TMVA/Types.h" /// This class is meant to evaluate several TMVA methods in datasets class TRestDataSetTMVA : public TRestMetadata { private: /// Name of the output file - std::string fOutputFileName = ""; //< + std::string fOutputFileName = ""; //< /// Name of the signal dataSet - std::string fDataSetSignal = ""; //< + std::string fDataSetSignal = ""; //< /// Name of the background dataset - std::string fDataSetBackground = ""; //< + std::string fDataSetBackground = ""; //< /// Name of the output path for the xml files - std::string fOutputPath = ""; //< + std::string fOutputPath = ""; //< /// Vector containing different obserbable names - std::vector fObsName; //< + std::vector fObsName; //< /// Add method to compute TMVA, https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf for more details - std::vector > fMethod; //< + std::vector > fMethod; //< /// Cuts over background dataset for PDF selection - TRestCut* fBackgroundCut = nullptr; //< + TRestCut* fBackgroundCut = nullptr; //< /// Cuts over signal dataset for PDF selection - TRestCut* fSignalCut = nullptr; //< + TRestCut* fSignalCut = nullptr; //< /// If true display ROC curve after evaluating all methods - bool fDrawROCCurve = true; //< + bool fDrawROCCurve = true; //< /// Map with supported TMVA methods, please add more if something is missing - const std::map fMethodMap = { //< - {"Likelihood", TMVA::Types::kLikelihood }, // Likelihood ("naive Bayes estimator") - {"LikelihoodKDE", TMVA::Types::kLikelihood }, // Use a kernel density estimator to approximate the PDFs - {"Fisher", TMVA::Types::kFisher }, // Fisher discriminant (same as LD) - {"BDT", TMVA::Types::kBDT }, //Boosted Decision Trees - {"MLP", TMVA::Types::kMLP } //Multi-Layer Perceptron (Neural Network) + const std::map fMethodMap = { + //< + {"Likelihood", TMVA::Types::kLikelihood}, // Likelihood ("naive Bayes estimator") + {"LikelihoodKDE", + TMVA::Types::kLikelihood}, // Use a kernel density estimator to approximate the PDFs + {"Fisher", TMVA::Types::kFisher}, // Fisher discriminant (same as LD) + {"BDT", TMVA::Types::kBDT}, // Boosted Decision Trees + {"MLP", TMVA::Types::kMLP} // Multi-Layer Perceptron (Neural Network) }; void Initialize() override; void InitFromConfigFile() override; public: - void PrintMetadata() override; void ComputeTMVA(); diff --git a/source/framework/analysis/inc/TRestDataSetTMVAClassification.h b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h index fe558ee4d..9ca4ab43d 100644 --- a/source/framework/analysis/inc/TRestDataSetTMVAClassification.h +++ b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h @@ -31,22 +31,22 @@ class TRestDataSetTMVAClassification : public TRestMetadata { private: /// Name of the output file - std::string fOutputFileName = ""; //< + std::string fOutputFileName = ""; //< /// Name of the dataSet to classify - std::string fDataSetName = ""; //< + std::string fDataSetName = ""; //< /// Name of the TMVA method - std::string fTmvaMethod = ""; //< + std::string fTmvaMethod = ""; //< /// Name of the TMVA weights file - std::string fTmvaFile = ""; //< + std::string fTmvaFile = ""; //< /// Vector containing different obserbable names - std::vector fObsName; //< + std::vector fObsName; //< /// Cuts over the dataset for PDF selection - TRestCut* fCut = nullptr; //< + TRestCut* fCut = nullptr; //< void Initialize() override; void InitFromConfigFile() override; diff --git a/source/framework/analysis/src/TRestDataSetTMVA.cxx b/source/framework/analysis/src/TRestDataSetTMVA.cxx index b6ea618fb..1d957efde 100644 --- a/source/framework/analysis/src/TRestDataSetTMVA.cxx +++ b/source/framework/analysis/src/TRestDataSetTMVA.cxx @@ -27,7 +27,7 @@ /// supported. TMVA requires a signal and a background dataset from which the /// different TMVA methods are computed. The different methods are evaluated /// in a set of observables that are provided in the RML file. Different cuts -/// can be performed in either the signal or the background datasets prior to +/// can be performed in either the signal or the background datasets prior to /// the TMVA evaluation. The output of this class is a root file which contains /// a signal and a background tree with the cuts applied and the different observables /// that are generated with the TMVA analysis. In addition, a folder is created @@ -40,12 +40,12 @@ /// * **dataSetBackground**: Name of the dataset file containing the background /// * **outputPath**: Name of the output path with the evaluation results /// * **drawROCCurve**: If true display the ROC curve for the evaluation of all methods -/// +/// /// The different observables for the TMVA analysis can be added with the following key: /// \code /// /// \endcode -/// +/// /// * **name**: Name of the observable be computed /// /// The different signal and background cuts can be added awith the following key: @@ -72,7 +72,8 @@ /// /// The different TMVA methods can be added wit the following key: /// \code -/// +/// /// \endcode /// The different parameters for adding TMVA methods are described below: /// * **name**: Name of the TMVA method, only Likelihood, LikelihoodKDE, Fisher, BDT and MLP @@ -107,11 +108,16 @@ /// /// /// -/// -/// -/// -/// -/// +/// +/// +/// +/// +/// /// /// \endcode /// @@ -148,14 +154,13 @@ #include "TRestDataSetTMVA.h" -#include "TRestDataSet.h" - +#include "ROOT/RDFHelpers.hxx" #include "TMVA/CrossValidation.h" #include "TMVA/DataLoader.h" -#include "ROOT/RDFHelpers.hxx" #include "TMVA/Factory.h" -#include "TMVA/Tools.h" #include "TMVA/TMVAGui.h" +#include "TMVA/Tools.h" +#include "TRestDataSet.h" ClassImp(TRestDataSetTMVA); @@ -249,20 +254,19 @@ void TRestDataSetTMVA::InitFromConfigFile() { std::string name = GetParameter("name", method, ""); std::string params = GetParameter("parameters", method, ""); if (name.empty() || params.empty()) { - RESTWarning << "Empty method" << RESTendl; + RESTWarning << "Empty method" << RESTendl; } else { - fMethod.push_back(std::make_pair(name, params)); + fMethod.push_back(std::make_pair(name, params)); } method = GetNextElement(method); } - if (fObsName.empty() ) { + if (fObsName.empty()) { RESTError << "No observables provided, exiting..." << RESTendl; exit(1); } if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", ""); - } ///////////////////////////////////////////// @@ -274,17 +278,17 @@ void TRestDataSetTMVA::InitFromConfigFile() { /// root files. /// void TRestDataSetTMVA::ComputeTMVA() { - - if(fOutputFileName.empty() || fOutputPath.empty() || fDataSetSignal.empty() || fDataSetBackground.empty() ){ - RESTError <<"Empty output file name, path, signal or background files "<Get("Background"); TMVA::Factory factory("TMVA_Classification", outputFile, - "!V:ROC:!Silent:Color:AnalysisType=Classification" ); + "!V:ROC:!Silent:Color:AnalysisType=Classification"); + + TMVA::DataLoader loader(fOutputPath); - TMVA::DataLoader loader (fOutputPath); - // Add observables for the evaluation - for(const auto &obs : fObsName)loader.AddVariable(obs); + for (const auto& obs : fObsName) loader.AddVariable(obs); - loader.AddSignalTree ( signalTree, 1.0); - loader.AddBackgroundTree( bckTree, 1.0); - loader.PrepareTrainingAndTestTree( "","", - ":SplitMode=Random" - ":NormMode=NumEvents" - ":!V"); + loader.AddSignalTree(signalTree, 1.0); + loader.AddBackgroundTree(bckTree, 1.0); + loader.PrepareTrainingAndTestTree("", "", + ":SplitMode=Random" + ":NormMode=NumEvents" + ":!V"); // Add different TMVA methods - for(const auto &[name, params] : fMethod){ - auto it = fMethodMap.find(name); - if(it == fMethodMap.end() ){ - RESTWarning << "Method " << name << " not supported "<second << " " << params << std::endl; - factory.BookMethod(&loader, it->second, name.c_str(), params.c_str()); + for (const auto& [name, params] : fMethod) { + auto it = fMethodMap.find(name); + if (it == fMethodMap.end()) { + RESTWarning << "Method " << name << " not supported " << RESTendl; + RESTWarning << "Currently supported methods: "; + for (const auto& [method, val] : fMethodMap) RESTWarning << method << ", "; + RESTWarning << RESTendl; + continue; + } + std::cout << "Added method " << name << " " << it->second << " " << params << std::endl; + factory.BookMethod(&loader, it->second, name.c_str(), params.c_str()); } - // Train, test and evaluate all methods - factory.TrainAllMethods(); - factory.TestAllMethods(); - factory.EvaluateAllMethods(); + // Train, test and evaluate all methods + factory.TrainAllMethods(); + factory.TestAllMethods(); + factory.EvaluateAllMethods(); - // Draw ROC curve - if (fDrawROCCurve && gApplication != nullptr && gApplication->IsRunning()){ - auto c1 = factory.GetROCCurve(&loader); - c1->Draw(); - } - - outputFile->Close(); + // Draw ROC curve + if (fDrawROCCurve && gApplication != nullptr && gApplication->IsRunning()) { + auto c1 = factory.GetROCCurve(&loader); + c1->Draw(); + } + outputFile->Close(); } ///////////////////////////////////////////// @@ -357,12 +360,12 @@ void TRestDataSetTMVA::PrintMetadata() { TRestMetadata::PrintMetadata(); RESTMetadata << " Observables to compute: " << RESTendl; - for (const auto & obs : fObsName) { + for (const auto& obs : fObsName) { RESTMetadata << obs << RESTendl; - } + } RESTMetadata << " TMVA Methods " << RESTendl; - for(const auto &[name, params] : fMethod){ - RESTMetadata << name << " "<< params << RESTendl; - } + for (const auto& [name, params] : fMethod) { + RESTMetadata << name << " " << params << RESTendl; + } RESTMetadata << "----" << RESTendl; } diff --git a/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx index 9c64c3100..2f500f059 100644 --- a/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx +++ b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx @@ -23,7 +23,7 @@ ///////////////////////////////////////////////////////////////////////// /// TRestDataSetTMVAClassification performs the classification of a given /// dataSet using as input the results of the TMVA evaluation methods -/// generated using TRestDataSetTMVA. Note that the observables used on +/// generated using TRestDataSetTMVA. Note that the observables used on /// TRestDataSetTMVA and TRestDataSetTMVA needs to match. This class generates /// as output a dataset with a new observable which is defined using the name of /// the TMVA method that has been used to classify the dataset. Only one TMVA @@ -36,7 +36,7 @@ /// * **tmvaMethod**: Name of the TMVA method used to classify /// * **outputFileName**: Name of the output dataset /// -/// +/// /// The different observables for the TMVA classification can be added with the following key: /// \code /// @@ -72,7 +72,7 @@ /// /// /// -/// +/// /// /// /// \endcode @@ -102,16 +102,15 @@ #include "TRestDataSetTMVAClassification.h" -#include "TRestDataSet.h" - +#include "ROOT/RDFHelpers.hxx" #include "TMVA/CrossValidation.h" #include "TMVA/DataLoader.h" +#include "TMVA/Factory.h" +#include "TMVA/RInferenceUtils.hxx" #include "TMVA/RReader.hxx" -#include "ROOT/RDFHelpers.hxx" #include "TMVA/RTensorUtils.hxx" -#include "TMVA/RInferenceUtils.hxx" -#include "TMVA/Factory.h" #include "TMVA/Tools.h" +#include "TRestDataSet.h" ClassImp(TRestDataSetTMVAClassification); @@ -175,7 +174,7 @@ void TRestDataSetTMVAClassification::InitFromConfigFile() { } TiXmlElement* cutele = GetElement("addCut"); - while (cutele != nullptr) { + while (cutele != nullptr) { std::string cutName = GetParameter("name", cutele, ""); if (!cutName.empty()) { if (fCut == nullptr) { @@ -185,9 +184,9 @@ void TRestDataSetTMVAClassification::InitFromConfigFile() { } } cutele = GetNextElement(cutele); - } + } - if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", ""); + if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", ""); } ///////////////////////////////////////////// @@ -202,83 +201,84 @@ void TRestDataSetTMVAClassification::InitFromConfigFile() { void TRestDataSetTMVAClassification::ClassifyTMVA() { PrintMetadata(); - if (fObsName.empty() ) { + if (fObsName.empty()) { RESTError << "No observables provided, exiting..." << RESTendl; exit(1); } - TMVA::Reader reader ( "!Color:!Silent" ); - std::vector var (fObsName.size()); + TMVA::Reader reader("!Color:!Silent"); + std::vector var(fObsName.size()); - // Add variables to the reader - for(unsigned int i=0; i &val) { return reader.EvaluateMVA(val, tmvaMethod.c_str()); }; + // Lambda for evaluation of the method + auto eval = [&reader = reader, &tmvaMethod = fTmvaMethod](const std::vector& val) { + return reader.EvaluateMVA(val, tmvaMethod.c_str()); + }; - TRestDataSet dataSet; - dataSet.Import(fDataSetName); - - auto df = dataSet.MakeCut(fCut); - - std::string obsName = fTmvaMethod +"_score"; + TRestDataSet dataSet; + dataSet.Import(fDataSetName); - // Ugly but cannot pass vector size to ROOT::RDF::PassAsVec - switch (fObsName.size()){ - case 1: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<1, double>(eval), fObsName); - break; - case 2: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<2, double>(eval), fObsName); - break; - case 3: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<3, double>(eval), fObsName); - break; - case 4: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<4, double>(eval), fObsName); - break; - case 5: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<5, double>(eval), fObsName); - break; - case 6: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<6, double>(eval), fObsName); - break; - case 7: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<7, double>(eval), fObsName); - break; - case 8: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<8, double>(eval), fObsName); - break; - case 9: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<9, double>(eval), fObsName); - break; - case 10: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<10, double>(eval), fObsName); - break; - case 11: - df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<11, double>(eval), fObsName); - break; - default: - RESTError<<"Number of observables "<(eval), fObsName); + break; + case 2: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<2, double>(eval), fObsName); + break; + case 3: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<3, double>(eval), fObsName); + break; + case 4: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<4, double>(eval), fObsName); + break; + case 5: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<5, double>(eval), fObsName); + break; + case 6: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<6, double>(eval), fObsName); + break; + case 7: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<7, double>(eval), fObsName); + break; + case 8: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<8, double>(eval), fObsName); + break; + case 9: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<9, double>(eval), fObsName); + break; + case 10: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<10, double>(eval), fObsName); + break; + case 11: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<11, double>(eval), fObsName); + break; + default: + RESTError << "Number of observables " << fObsName.size() << " is not supported" << RESTendl; + exit(1); + } + + dataSet.SetDataFrame(df); + + if (!fOutputFileName.empty()) { if (TRestTools::GetFileNameExtension(fOutputFileName) == "root") { dataSet.Export(fOutputFileName); TFile* f = TFile::Open(fOutputFileName.c_str(), "UPDATE"); this->Write(); f->Close(); } - } - + } } ///////////////////////////////////////////// diff --git a/source/framework/core/src/TRestDataSet.cxx b/source/framework/core/src/TRestDataSet.cxx index 618f40667..566b279c3 100644 --- a/source/framework/core/src/TRestDataSet.cxx +++ b/source/framework/core/src/TRestDataSet.cxx @@ -918,7 +918,7 @@ void TRestDataSet::Import(const std::string& fileName, bool enableMT) { return; } - if(enableMT)ROOT::EnableImplicitMT(); + if (enableMT) ROOT::EnableImplicitMT(); fDataSet = ROOT::RDataFrame("AnalysisTree", fileName); diff --git a/source/framework/core/src/TRestRun.cxx b/source/framework/core/src/TRestRun.cxx index 9b7a88ce0..7fb9b2695 100644 --- a/source/framework/core/src/TRestRun.cxx +++ b/source/framework/core/src/TRestRun.cxx @@ -985,7 +985,7 @@ TFile* TRestRun::MergeToOutputFile(vector filenames, string outputfilena RESTDebug << "TRestRun::FormOutputFile. target : " << outputfilename << RESTendl; string filename; - TFileMerger m (false); + TFileMerger m(false); if (outputfilename == "") { filename = fOutputFileName; RESTInfo << "Creating file : " << filename << RESTendl; @@ -1083,11 +1083,10 @@ TFile* TRestRun::UpdateOutputFile() { /// TFile* TRestRun::OpenAndUpdateOutputFile() { if (fOutputFile == nullptr) { - fOutputFile = TFile::Open(fOutputFileName,"UPDATE"); + fOutputFile = TFile::Open(fOutputFileName, "UPDATE"); } - return UpdateOutputFile(); - + return UpdateOutputFile(); } /////////////////////////////////////////////// From 895872b4444f7814d83ea897352c69c590ec52c8 Mon Sep 17 00:00:00 2001 From: juanan Date: Thu, 15 Jun 2023 23:24:05 +0200 Subject: [PATCH 8/8] Adding root TMVA to CMakelist --- CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d4e7928f..ea7ed0c5e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,7 +79,8 @@ set(ROOT_REQUIRED_LIBRARIES Gdml Minuit Spectrum - XMLIO) + XMLIO + TMVA) # Auto schema evolution for ROOT if (NOT DEFINED REST_SE)