diff --git a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h index 9fdb0c530d53c..4f13181d1d250 100644 --- a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h +++ b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h @@ -41,6 +41,7 @@ class RooJSONFactoryWSTool { static constexpr bool useListsInsteadOfDicts = true; static bool allowExportInvalidNames; static bool allowSanitizeNames; + static bool importNoDomainParametersAsRooConstVars; static RooWorkspace sanitizeWS(const RooWorkspace &ws); static RooWorkspace cleanWS(const RooWorkspace &ws, bool onlyModelConfig = false); diff --git a/roofit/hs3/src/Domains.cxx b/roofit/hs3/src/Domains.cxx index b7f97b859a443..2c0b32dfba52b 100644 --- a/roofit/hs3/src/Domains.cxx +++ b/roofit/hs3/src/Domains.cxx @@ -12,6 +12,8 @@ #include "Domains.h" +#include +#include #include #include #include @@ -25,6 +27,28 @@ namespace Detail { constexpr static auto defaultDomainName = "default_domain"; +namespace { + +double readBound(RooFit::Detail::JSONNode const &node, const char *key, double defaultValue) +{ + if (!node.has_child(key)) { + return defaultValue; + } + auto const &bound = node[key]; + return bound.is_null() ? defaultValue : bound.val_double(); +} + +void writeBound(RooFit::Detail::JSONNode &node, double value) +{ + if (RooNumber::isInfinite(value)) { + node.set_null(); + } else { + node << value; + } +} + +} // namespace + void Domains::populate(RooWorkspace &ws) const { auto default_domain = _map.find(defaultDomainName); @@ -49,12 +73,12 @@ void Domains::readVariable(const char *name, double min, double max, const char void Domains::readVariable(RooRealVar const &var) { - readVariable(var.GetName(), var.getMin(), var.getMax(), defaultDomainName); + _map[defaultDomainName].readVariable(var.GetName(), var.getBinning()); for (const auto &bname : var.getBinningNames()) { if (bname.empty()) continue; auto &binning = var.getBinning(bname.c_str()); - readVariable(var.GetName(), binning.lowBound(), binning.highBound(), bname.c_str()); + _map[bname].readVariable(var.GetName(), binning); } } @@ -91,25 +115,89 @@ void Domains::writeJSON(RooFit::Detail::JSONNode &node) const } } +bool Domains::hasVariable(const char *name) const +{ + for (auto const &domain : _map) { + if (domain.second.hasVariable(name)) { + return true; + } + } + return false; +} + void Domains::ProductDomain::readVariable(RooRealVar const &var) { - readVariable(var.GetName(), var.getMin(), var.getMax()); + readVariable(var.GetName(), var.getBinning()); } -void Domains::ProductDomain::readVariable(const char *name, double min, double max) +void Domains::ProductDomain::readBinning(ProductDomainElement &elem, RooAbsBinning const &binning) { - if (RooNumber::isInfinite(min) && RooNumber::isInfinite(max)) + elem.hasNBins = false; + elem.nBins = 0; + elem.edges.clear(); + + const int nBins = binning.numBins(); + if (nBins <= 0) { return; + } + + if (binning.isUniform()) { + elem.hasNBins = true; + elem.nBins = nBins; + } else { + elem.edges.push_back(binning.binLow(0)); + for (int i = 0; i < nBins; ++i) { + elem.edges.push_back(binning.binHigh(i)); + } + } +} +void Domains::ProductDomain::readVariable(const char *name, RooAbsBinning const &binning) +{ auto &elem = _map[name]; - if (!RooNumber::isInfinite(min)) { - elem.hasMin = true; - elem.min = min; + elem.hasMin = true; + elem.min = binning.lowBound(); + elem.hasMax = true; + elem.max = binning.highBound(); + readBinning(elem, binning); +} + +void Domains::ProductDomain::readVariable(const char *name, double min, double max) +{ + auto &elem = _map[name]; + + elem.hasMin = true; + elem.min = min; + elem.hasMax = true; + elem.max = max; + elem.hasNBins = false; + elem.nBins = 0; + elem.edges.clear(); +} + +void Domains::ProductDomain::applyBinning(RooRealVar &var, ProductDomainElement const &elem, const char *name) +{ + if (!elem.edges.empty()) { + RooBinning binning(elem.edges.front(), elem.edges.back()); + for (double edge : elem.edges) { + binning.addBoundary(edge); + } + var.setBinning(binning, name); + } else if (elem.hasNBins && elem.nBins != 0) { + var.setBins(elem.nBins, name); } - if (!RooNumber::isInfinite(max)) { - elem.hasMax = true; - elem.max = max; +} + +void Domains::ProductDomain::writeBinning(RooFit::Detail::JSONNode &node, ProductDomainElement const &elem) +{ + if (!elem.edges.empty()) { + auto &edges = node["edges"].set_seq(); + for (double edge : elem.edges) { + edges.append_child() << edge; + } + } else if (elem.hasNBins && elem.nBins != 0) { + node["nbins"] << elem.nBins; } } void Domains::ProductDomain::writeVariable(RooRealVar &var) const @@ -117,12 +205,29 @@ void Domains::ProductDomain::writeVariable(RooRealVar &var) const auto found = _map.find(var.GetName()); if (found != _map.end()) { auto const &elem = found->second; - if (elem.hasMin) - var.setMin(elem.min); - if (elem.hasMax) - var.setMax(elem.max); + if (elem.hasMin) { + if (RooNumber::isInfinite(elem.min)) { + var.removeMin(); + } else { + var.setMin(elem.min); + } + } + if (elem.hasMax) { + if (RooNumber::isInfinite(elem.max)) { + var.removeMax(); + } else { + var.setMax(elem.max); + } + } + applyBinning(var, elem); } } + +bool Domains::ProductDomain::hasVariable(const char *name) const +{ + return _map.find(name) != _map.end(); +} + void Domains::ProductDomain::readJSON(RooFit::Detail::JSONNode const &node) { if (!node.has_child("type") || node["type"].val() != "product_domain") { @@ -132,13 +237,34 @@ void Domains::ProductDomain::readJSON(RooFit::Detail::JSONNode const &node) auto &elem = _map[RooJSONFactoryWSTool::name(varNode)]; if (varNode.has_child("min")) { - elem.min = varNode["min"].val_double(); + elem.min = readBound(varNode, "min", -RooNumber::infinity()); elem.hasMin = true; } if (varNode.has_child("max")) { - elem.max = varNode["max"].val_double(); + elem.max = readBound(varNode, "max", RooNumber::infinity()); elem.hasMax = true; } + if (varNode.has_child("edges")) { + elem.hasNBins = false; + elem.edges.clear(); + for (auto const &edge : varNode["edges"].children()) { + elem.edges.push_back(edge.val_double()); + } + if (!elem.edges.empty()) { + if (!elem.hasMin) { + elem.min = elem.edges.front(); + elem.hasMin = true; + } + if (!elem.hasMax) { + elem.max = elem.edges.back(); + elem.hasMax = true; + } + } + } else if (varNode.has_child("nbins")) { + elem.nBins = varNode["nbins"].val_int(); + elem.hasNBins = elem.nBins != 0; + elem.edges.clear(); + } } } void Domains::ProductDomain::writeJSON(RooFit::Detail::JSONNode &node) const @@ -152,21 +278,22 @@ void Domains::ProductDomain::writeJSON(RooFit::Detail::JSONNode &node) const for (auto const &item : _map) { auto const &elem = item.second; RooFit::Detail::JSONNode &varnode = RooJSONFactoryWSTool::appendNamedChild(variablesNode, item.first); - if (elem.hasMin) - varnode["min"] << elem.min; - if (elem.hasMax) - varnode["max"] << elem.max; + writeBound(varnode["min"], elem.hasMin ? elem.min : -RooNumber::infinity()); + writeBound(varnode["max"], elem.hasMax ? elem.max : RooNumber::infinity()); + writeBinning(varnode, elem); } } void Domains::ProductDomain::populate(RooWorkspace &ws) const { for (auto const &item : _map) { const auto &name = item.first; - if (!ws.var(name)) { + if (!ws.arg(name)) { const auto &elem = item.second; const double vMin = elem.hasMin ? elem.min : -RooNumber::infinity(); const double vMax = elem.hasMax ? elem.max : RooNumber::infinity(); - ws.import(RooRealVar{name.c_str(), name.c_str(), vMin, vMax}); + RooRealVar var{name.c_str(), name.c_str(), vMin, vMax}; + applyBinning(var, elem); + ws.import(var); } } } @@ -176,7 +303,10 @@ void Domains::ProductDomain::registerBinnings(const char *name, RooWorkspace &ws auto *var = ws.var(item.first); if (!var) continue; - var->setRange(name, item.second.min, item.second.max); + const double vMin = item.second.hasMin ? item.second.min : -RooNumber::infinity(); + const double vMax = item.second.hasMax ? item.second.max : RooNumber::infinity(); + var->setRange(name, vMin, vMax); + applyBinning(*var, item.second, name); } } diff --git a/roofit/hs3/src/Domains.h b/roofit/hs3/src/Domains.h index d9ccae059970d..f45f78c108800 100644 --- a/roofit/hs3/src/Domains.h +++ b/roofit/hs3/src/Domains.h @@ -17,6 +17,7 @@ #include #include +class RooAbsBinning; class RooRealVar; class RooWorkspace; @@ -42,17 +43,22 @@ class Domains { void readJSON(RooFit::Detail::JSONNode const &); void writeJSON(RooFit::Detail::JSONNode &) const; + bool hasVariable(const char *name) const; + void populate(RooWorkspace &ws) const; class ProductDomain { public: void readVariable(const RooRealVar &); + void readVariable(const char *name, RooAbsBinning const &binning); void readVariable(const char *name, double min, double max); void writeVariable(RooRealVar &) const; void readJSON(RooFit::Detail::JSONNode const &); void writeJSON(RooFit::Detail::JSONNode &) const; + bool hasVariable(const char *name) const; + void populate(RooWorkspace &ws) const; void registerBinnings(const char *name, RooWorkspace &ws) const; @@ -62,8 +68,15 @@ class Domains { bool hasMax = false; double min = 0.0; double max = 0.0; + bool hasNBins = false; + int nBins = 0; + std::vector edges; }; + static void applyBinning(RooRealVar &var, ProductDomainElement const &elem, const char *name = nullptr); + static void readBinning(ProductDomainElement &elem, RooAbsBinning const &binning); + static void writeBinning(RooFit::Detail::JSONNode &node, ProductDomainElement const &elem); + std::map _map; }; diff --git a/roofit/hs3/src/JSONFactories_HistFactory.cxx b/roofit/hs3/src/JSONFactories_HistFactory.cxx index ea1b3de1672af..eb1844d49ebd3 100644 --- a/roofit/hs3/src/JSONFactories_HistFactory.cxx +++ b/roofit/hs3/src/JSONFactories_HistFactory.cxx @@ -283,6 +283,36 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa "'"); } } +double poissonTau(RooPoisson const &constraint, RooAbsArg const &gamma) +{ + auto const *mean = dynamic_cast(&constraint.getMean()); + if (!mean) { + RooJSONFactoryWSTool::error("Poisson gamma constraint mean is not a RooProduct: " + + std::string(constraint.GetName())); + } + + for (RooAbsArg *arg : mean->servers()) { + if (arg == &gamma) { + continue; + } + + if (auto const *tau = dynamic_cast(arg)) { + return tau->getVal(); + } + + // Imported workspaces can sometimes represent + // constants as constant RooRealVars. + if (auto const *real = dynamic_cast(arg)) { + if (real->isConstant() || endsWith(std::string(real->GetName()), "_tau")) { + return real->getVal(); + } + } + } + + RooJSONFactoryWSTool::error("Could not find tau component in Poisson gamma constraint mean: " + + std::string(constraint.GetName())); + return std::numeric_limits::quiet_NaN(); +} bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet const &varlist, RooAbsArg const *mcStatObject, const std::string &fprefix, const JSONNode &p, @@ -334,6 +364,7 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con // this is dealt with at a different place, ignore it for now } else if (modtype == "normfactor") { RooRealVar &constrParam = getOrCreate(ws, sysname, 1., -3, 5); + constrParam.setError(0.0); normElems.add(constrParam); if (mod.has_child("constraint_name") || mod.has_child("constraint_type")) { // for norm factors, constraints are optional @@ -1054,7 +1085,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons if (constraint) { sample.barlowBeestonLightConstraintType = constraint->IsA(); if (RooPoisson *constraint_p = dynamic_cast(constraint)) { - double erel = 1. / std::sqrt(constraint_p->getX().getVal()); + double erel = 1. / std::sqrt(poissonTau(*constraint_p, *g)); channel.rel_errors[idx] = erel; } else if (RooGaussian *constraint_g = dynamic_cast(constraint)) { double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal(); @@ -1088,7 +1119,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons if (!constraint) { sys.constraints.push_back(0.0); } else if (auto constraint_p = dynamic_cast(constraint)) { - sys.constraints.push_back(1. / std::sqrt(constraint_p->getX().getVal())); + sys.constraints.push_back(1. / std::sqrt(poissonTau(*constraint_p, *g))); if (!sys.constraint) { sys.constraintType = RooPoisson::Class(); } diff --git a/roofit/hs3/src/JSONFactories_RooFitCore.cxx b/roofit/hs3/src/JSONFactories_RooFitCore.cxx index 42fe920e04e18..6a9fd7ffad156 100644 --- a/roofit/hs3/src/JSONFactories_RooFitCore.cxx +++ b/roofit/hs3/src/JSONFactories_RooFitCore.cxx @@ -282,9 +282,9 @@ class RooPolynomialFactory : public RooFit::JSONIO::Importer { // As long as the coefficients match the default coefficients in // RooFit, we don't have to instantiate RooFit objects but can // increase the lowestOrder flag. - if (order == 0 && coef.val() == "1.0") { + if (order == 0 && (coef.val() == "1.0" || coef.val() == "1")) { ++lowestOrder; - } else if (coefs.empty() && coef.val() == "0.0") { + } else if (coefs.empty() && (coef.val() == "0.0" || coef.val() == "0")) { ++lowestOrder; } else { coefs.add(*tool->request(coef.val(), name)); @@ -817,7 +817,7 @@ void writePolynomialBody(const Pdf *pdf, JSONNode &elem) elem["x"] << pdf->x().GetName(); auto &coefs = elem["coefficients"].set_seq(); for (int i = 0; i < pdf->lowestOrder(); ++i) { - coefs.append_child() << (i == 0 ? "1.0" : "0.0"); + coefs.append_child() << (i == 0 ? 1.0 : 0.0); } for (const auto &coef : pdf->coefList()) { coefs.append_child() << coef->GetName(); diff --git a/roofit/hs3/src/RooJSONFactoryWSTool.cxx b/roofit/hs3/src/RooJSONFactoryWSTool.cxx index 65634d6da2a40..0d93c738aae54 100644 --- a/roofit/hs3/src/RooJSONFactoryWSTool.cxx +++ b/roofit/hs3/src/RooJSONFactoryWSTool.cxx @@ -162,6 +162,26 @@ struct Var { Var(int n) : nbins(n), min(0), max(n) {} }; +void exportAxis(JSONNode &obsNode, RooRealVar const &var) +{ + std::string name = var.GetName(); + RooJSONFactoryWSTool::testValidName(name, false); + obsNode["name"] << name; + + auto const &binning = var.getBinning(); + if (binning.isUniform()) { + obsNode["min"] << var.getMin(); + obsNode["max"] << var.getMax(); + obsNode["nbins"] << var.getBins(); + } else { + auto &edges = obsNode["edges"].set_seq(); + edges.append_child() << binning.binLow(0); + for (int i = 0; i < binning.numBins(); ++i) { + edges.append_child() << binning.binHigh(i); + } + } +} + /** * @brief Check if a string represents a valid number. * @@ -519,9 +539,6 @@ void exportAttributes(const RooAbsArg *arg, JSONNode &rootnode) node = &RooJSONFactoryWSTool::getRooFitInternal(rootnode, "attributes").set_map()[arg->GetName()].set_map(); }; - // RooConstVars are not a thing in HS3, and also for RooFit they are not - // that important: they are just constants. So we don't need to remember - // any information about them. if (dynamic_cast(arg)) { return; } @@ -966,6 +983,7 @@ bool RooJSONFactoryWSTool::isValidName(const std::string &str) bool RooJSONFactoryWSTool::allowExportInvalidNames(true); bool RooJSONFactoryWSTool::allowSanitizeNames(true); +bool RooJSONFactoryWSTool::importNoDomainParametersAsRooConstVars(true); bool RooJSONFactoryWSTool::testValidName(const std::string &name, bool forceError) { if (!RooJSONFactoryWSTool::isValidName(name)) { @@ -1032,6 +1050,8 @@ RooAbsReal *RooJSONFactoryWSTool::requestImpl(const std::string &obj return pdf; if (RooRealVar *var = requestImpl(objname)) return var; + if (RooAbsReal *retval = _workspace.function(objname)) + return retval; auto it = _functionsByName.find(objname); if (it != _functionsByName.end()) { this->importFunction(*it->second, true); @@ -1072,11 +1092,11 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, bo var["value"] << rrv->getVal(); if (rrv->isConstant() && storeConstant) { var["const"] << rrv->isConstant(); - } else { + } else if (storeBins) { var["min"] << rrv->getMin(); var["max"] << rrv->getMax(); } - if (rrv->getBins() != 100 && storeBins) { + if (rrv->getBins() != 0 && storeBins) { var["nbins"] << rrv->getBins(); } _domains->readVariable(*rrv); @@ -1642,7 +1662,11 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data) // this really is an unbinned dataset output["type"] << "unbinned"; - exportVariables(variables, output["axes"], false, true); + auto &observablesNode = output["axes"].set_seq(); + for (auto *var : static_range_cast(variables)) { + _domains->readVariable(*var); + exportAxis(observablesNode.append_child().set_map(), *var); + } auto &coords = output["entries"].set_seq(); std::vector weightVals; bool hasNonUnityWeights = false; @@ -1769,7 +1793,7 @@ void RooJSONFactoryWSTool::importVariable(const JSONNode &p) std::string name(RooJSONFactoryWSTool::name(p)); RooJSONFactoryWSTool::testValidName(name, true); - if (_workspace.var(name)) + if (_workspace.arg(name)) return; if (!p.is_map()) { std::stringstream ss; @@ -1777,14 +1801,12 @@ void RooJSONFactoryWSTool::importVariable(const JSONNode &p) oocoutE(nullptr, InputArguments) << ss.str() << std::endl; return; } - if (_attributesNode) { - if (auto *attrNode = _attributesNode->find(name)) { - // We should not create RooRealVar objects for RooConstVars! - if (attrNode->has_child("is_const_var") && (*attrNode)["is_const_var"].val_int() == 1) { - wsEmplace(name, p["value"].val_double()); - return; - } + if (importNoDomainParametersAsRooConstVars && !_domains->hasVariable(name.c_str())) { + if (!p.has_child("value")) { + RooJSONFactoryWSTool::error("cannot instantiate RooConstVar '" + name + "' without \"value\"!"); } + wsEmplace(name, p["value"].val_double()); + return; } configureVariable(*_domains, p, wsEmplace(name, 1.)); } @@ -2240,16 +2262,16 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n) error(ss.str()); } + _rootnodeInput = &n; + + _attributesNode = findRooFitInternal(*_rootnodeInput, "attributes"); + _domains = std::make_unique(); if (auto domains = n.find("domains")) { _domains->readJSON(*domains); } _domains->populate(_workspace); - _rootnodeInput = &n; - - _attributesNode = findRooFitInternal(*_rootnodeInput, "attributes"); - // Build name-keyed indices over the "functions" and "distributions" // sequences. Without these, every cross-reference resolved during import // (e.g. dependencies of a PiecewiseInterpolation, or factory-expression diff --git a/roofit/hs3/test/testRooFitHS3.cxx b/roofit/hs3/test/testRooFitHS3.cxx index d949326ccd43f..3fcf10dc044de 100644 --- a/roofit/hs3/test/testRooFitHS3.cxx +++ b/roofit/hs3/test/testRooFitHS3.cxx @@ -6,6 +6,8 @@ #include #include +#include +#include #include #include #include @@ -146,6 +148,38 @@ std::string parameterStepWidthsNode(std::string const &json) return json.substr(begin, end - begin + 1); } +std::string defaultDomainAxesNode(std::string const &json) +{ + const std::string key = "\"domains\":["; + const auto domainsBegin = json.find(key); + if (domainsBegin == std::string::npos) { + return ""; + } + const auto axesBegin = json.find("\"axes\":[", domainsBegin); + if (axesBegin == std::string::npos) { + return ""; + } + const auto axesEnd = json.find("}]", axesBegin); + if (axesEnd == std::string::npos) { + return ""; + } + return json.substr(axesBegin, axesEnd - axesBegin + 2); +} + +class ScopedNoDomainConstVarImportFlag { +public: + explicit ScopedNoDomainConstVarImportFlag(bool value) + : _oldValue{RooJSONFactoryWSTool::importNoDomainParametersAsRooConstVars} + { + RooJSONFactoryWSTool::importNoDomainParametersAsRooConstVars = value; + } + + ~ScopedNoDomainConstVarImportFlag() { RooJSONFactoryWSTool::importNoDomainParametersAsRooConstVars = _oldValue; } + +private: + bool _oldValue; +}; + } // namespace // Test that the IO of attributes and string attributes works. @@ -179,6 +213,118 @@ TEST(RooFitHS3, AttributesIO) EXPECT_STREQ(pdf.getStringAttribute("key1"), nullptr) << "unexpected string attribute found!"; } +TEST(RooFitHS3, ParameterPointsDoNotExportRanges) +{ + RooWorkspace ws{"workspace"}; + ws.factory("Gaussian::pdf(x[0, 10], mu[1, -5, 5], sigma[2, 0.1, 10])"); + + const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString(); + auto tree = RooFit::Detail::JSONTree::create(json); + + for (auto const &point : tree->rootnode()["parameter_points"].children()) { + for (auto const ¶meter : point["parameters"].children()) { + EXPECT_FALSE(parameter.has_child("min")) << parameter["name"].val(); + EXPECT_FALSE(parameter.has_child("max")) << parameter["name"].val(); + } + } +} + +TEST(RooFitHS3, ProductDomainEntriesExportExplicitBounds) +{ + RooRealVar x{"x", "x", 0.0, -10.0, 10.0}; + RooRealVar mean{"mean", "mean", 0.0}; + RooRealVar sigma{"sigma", "sigma", 1.0, 0.1, 10.0}; + RooGaussian gauss{"gauss", "gauss", x, mean, sigma}; + + RooWorkspace ws{"workspace"}; + ws.import(gauss, RooFit::Silence()); + + const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString(); + auto tree = RooFit::Detail::JSONTree::create(json); + auto const *defaultDomain = RooJSONFactoryWSTool::findNamedChild(tree->rootnode()["domains"], "default_domain"); + ASSERT_NE(defaultDomain, nullptr); + + auto const *xAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "x"); + ASSERT_NE(xAxis, nullptr); + ASSERT_TRUE(xAxis->has_child("min")); + ASSERT_TRUE(xAxis->has_child("max")); + EXPECT_FALSE((*xAxis)["min"].is_null()); + EXPECT_FALSE((*xAxis)["max"].is_null()); + EXPECT_DOUBLE_EQ((*xAxis)["min"].val_double(), -10.0); + EXPECT_DOUBLE_EQ((*xAxis)["max"].val_double(), 10.0); + + auto const *meanAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "mean"); + ASSERT_NE(meanAxis, nullptr); + ASSERT_TRUE(meanAxis->has_child("min")); + ASSERT_TRUE(meanAxis->has_child("max")); + EXPECT_TRUE((*meanAxis)["min"].is_null()); + EXPECT_TRUE((*meanAxis)["max"].is_null()); + + RooWorkspace imported; + ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json)); + auto *importedMean = imported.var("mean"); + ASSERT_NE(importedMean, nullptr); + EXPECT_TRUE(std::isinf(importedMean->getMin())); + EXPECT_LT(importedMean->getMin(), 0.0); + EXPECT_TRUE(std::isinf(importedMean->getMax())); + EXPECT_GT(importedMean->getMax(), 0.0); +} + +TEST(RooFitHS3, ProductDomainEntriesExportBinning) +{ + RooRealVar uniform{"uniform", "uniform", 0.0, 1.0}; + uniform.setBins(7); + + RooRealVar nonuniform{"nonuniform", "nonuniform", 0.0, 3.0}; + RooBinning nonuniformBinning{0.0, 3.0}; + nonuniformBinning.addBoundary(1.0); + nonuniformBinning.addBoundary(1.5); + nonuniform.setBinning(nonuniformBinning); + + RooAddition sum{"sum", "sum", RooArgList{uniform, nonuniform}}; + + RooWorkspace ws{"workspace"}; + ws.import(sum, RooFit::Silence()); + + const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString(); + auto tree = RooFit::Detail::JSONTree::create(json); + auto const *defaultDomain = RooJSONFactoryWSTool::findNamedChild(tree->rootnode()["domains"], "default_domain"); + ASSERT_NE(defaultDomain, nullptr); + + auto const *uniformAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "uniform"); + ASSERT_NE(uniformAxis, nullptr); + ASSERT_TRUE(uniformAxis->has_child("nbins")); + EXPECT_EQ((*uniformAxis)["nbins"].val_int(), 7); + EXPECT_FALSE(uniformAxis->has_child("edges")); + + auto const *nonuniformAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "nonuniform"); + ASSERT_NE(nonuniformAxis, nullptr); + ASSERT_TRUE(nonuniformAxis->has_child("edges")); + EXPECT_FALSE(nonuniformAxis->has_child("nbins")); + auto const &edges = (*nonuniformAxis)["edges"]; + ASSERT_EQ(edges.num_children(), 4u); + EXPECT_DOUBLE_EQ(edges.child(0).val_double(), 0.0); + EXPECT_DOUBLE_EQ(edges.child(1).val_double(), 1.0); + EXPECT_DOUBLE_EQ(edges.child(2).val_double(), 1.5); + EXPECT_DOUBLE_EQ(edges.child(3).val_double(), 3.0); + + RooWorkspace imported; + ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json)); + auto *importedUniform = imported.var("uniform"); + ASSERT_NE(importedUniform, nullptr); + EXPECT_EQ(importedUniform->getBins(), 7); + + auto *importedNonuniform = imported.var("nonuniform"); + ASSERT_NE(importedNonuniform, nullptr); + auto const &importedBinning = importedNonuniform->getBinning(); + EXPECT_FALSE(importedBinning.isUniform()); + ASSERT_EQ(importedBinning.numBins(), 3); + EXPECT_DOUBLE_EQ(importedBinning.binLow(0), 0.0); + EXPECT_DOUBLE_EQ(importedBinning.binHigh(0), 1.0); + EXPECT_DOUBLE_EQ(importedBinning.binHigh(1), 1.5); + EXPECT_DOUBLE_EQ(importedBinning.binHigh(2), 3.0); +} + TEST(RooFitHS3, ParameterStepWidthsModelConfigRoundTrip) { RooWorkspace ws1{"workspace"}; @@ -294,6 +440,7 @@ TEST(RooFitHS3, ParameterStepWidthsImportAfterDefaultSnapshot) })"; RooWorkspace ws{"workspace"}; + ScopedNoDomainConstVarImportFlag flagGuard{false}; ASSERT_TRUE(RooJSONFactoryWSTool{ws}.importJSONfromString(json)); ASSERT_NE(ws.var("mu"), nullptr); @@ -395,6 +542,115 @@ TEST(RooFitHS3, RooGaussian) EXPECT_EQ(status, 0); } +TEST(RooFitHS3, RooGaussianConstVarSigmaExport) +{ + ScopedNoDomainConstVarImportFlag flagGuard{true}; + + RooRealVar x{"x", "x", 0.0, -10.0, 10.0}; + RooRealVar mean{"mean", "mean", 0.0}; + mean.setConstant(true); + + RooConstVar sigmaConst{"sigma_const", "sigma_const", 1.0}; + RooGaussian gaussConst{"gauss_const", "gauss_const", x, mean, sigmaConst}; + + RooGaussian gaussLiteral{"gauss_literal", "gauss_literal", x, mean, RooFit::RooConst(2.0)}; + + RooRealVar sigmaReal{"sigma_real", "sigma_real", 1.0, 0.1, 10.0}; + sigmaReal.setConstant(true); + RooGaussian gaussReal{"gauss_real", "gauss_real", x, mean, sigmaReal}; + + RooWorkspace ws; + ws.import(gaussConst, RooFit::Silence()); + ws.import(gaussLiteral, RooFit::RecycleConflictNodes(), RooFit::Silence()); + ws.import(gaussReal, RooFit::RecycleConflictNodes(), RooFit::Silence()); + + const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString(); + const std::string domainAxes = defaultDomainAxesNode(json); + ASSERT_FALSE(domainAxes.empty()) << json; + + EXPECT_NE(json.find("\"sigma\":\"sigma_const\""), std::string::npos); + EXPECT_NE(json.find("\"name\":\"sigma_const\""), std::string::npos); + EXPECT_EQ(json.find("is_const_var"), std::string::npos); + EXPECT_EQ(json.find("\"sigma\":1.0"), std::string::npos); + EXPECT_NE(json.find("\"sigma\":2.0"), std::string::npos); + EXPECT_EQ(domainAxes.find("\"name\":\"sigma_const\""), std::string::npos) << domainAxes; + + EXPECT_NE(json.find("\"sigma\":\"sigma_real\""), std::string::npos); + EXPECT_NE(json.find("\"name\":\"sigma_real\""), std::string::npos); + EXPECT_NE(domainAxes.find("\"name\":\"sigma_real\""), std::string::npos) << domainAxes; + + // The unbounded constant RooRealVar is still a RooRealVar, so it gets a + // domain axis with explicit null bounds that distinguishes it from a RooConstVar. + auto tree = RooFit::Detail::JSONTree::create(json); + auto const *defaultDomain = RooJSONFactoryWSTool::findNamedChild(tree->rootnode()["domains"], "default_domain"); + ASSERT_NE(defaultDomain, nullptr); + auto const *meanAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "mean"); + ASSERT_NE(meanAxis, nullptr); + EXPECT_TRUE((*meanAxis)["min"].is_null()); + EXPECT_TRUE((*meanAxis)["max"].is_null()); + + RooWorkspace imported; + ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json)); + EXPECT_NE(dynamic_cast(imported.obj("sigma_const")), nullptr); + EXPECT_EQ(imported.var("sigma_const"), nullptr); + EXPECT_NE(dynamic_cast(imported.obj("sigma_real")), nullptr); + EXPECT_NE(dynamic_cast(imported.obj("mean")), nullptr); + + const std::string roundTripJson = RooJSONFactoryWSTool{imported}.exportJSONtoString(); + const std::string roundTripDomainAxes = defaultDomainAxesNode(roundTripJson); + EXPECT_NE(roundTripJson.find("\"sigma\":\"sigma_const\""), std::string::npos); + EXPECT_EQ(roundTripJson.find("is_const_var"), std::string::npos); + EXPECT_EQ(roundTripDomainAxes.find("\"name\":\"sigma_const\""), std::string::npos) << roundTripDomainAxes; + + const std::string legacyJson = R"({ + "metadata":{"hs3_version":"0.2"}, + "parameter_points":[{"name":"default_values","parameters":[ + {"name":"x","value":0.0}, + {"name":"mean","value":0.0}, + {"name":"sigma_const","value":1.0,"const":true} + ]}], + "distributions":[{"name":"gauss","type":"gaussian_dist","x":"x","mean":"mean","sigma":"sigma_const"}] + })"; + RooWorkspace legacyImport; + { + ScopedNoDomainConstVarImportFlag legacyFlagGuard{false}; + ASSERT_TRUE(RooJSONFactoryWSTool{legacyImport}.importJSONfromString(legacyJson)); + } + EXPECT_NE(dynamic_cast(legacyImport.obj("sigma_const")), nullptr); + EXPECT_EQ(dynamic_cast(legacyImport.obj("sigma_const")), nullptr); +} + +TEST(RooFitHS3, RooConstVarCollectionProxyExport) +{ + ScopedNoDomainConstVarImportFlag flagGuard{true}; + + RooRealVar x{"x", "x", 0.0, -10.0, 10.0}; + RooRealVar mean1{"mean1", "mean1", -1.0}; + RooRealVar mean2{"mean2", "mean2", 1.0}; + RooRealVar sigma{"sigma", "sigma", 1.0, 0.1, 10.0}; + + RooGaussian g1{"g1", "g1", x, mean1, sigma}; + RooGaussian g2{"g2", "g2", x, mean2, sigma}; + RooConstVar frac{"frac_const", "frac_const", 0.25}; + RooAddPdf model{"model", "model", RooArgList{g1, g2}, RooArgList{frac}}; + + RooWorkspace ws; + ws.import(model, RooFit::Silence()); + + const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString(); + const std::string domainAxes = defaultDomainAxesNode(json); + ASSERT_FALSE(domainAxes.empty()) << json; + EXPECT_NE(json.find("\"coefficients\":[\"frac_const\"]"), std::string::npos); + EXPECT_NE(json.find("\"name\":\"frac_const\""), std::string::npos); + EXPECT_EQ(json.find("is_const_var"), std::string::npos); + EXPECT_EQ(domainAxes.find("\"name\":\"frac_const\""), std::string::npos) << domainAxes; + + RooWorkspace imported; + ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json)); + EXPECT_NE(dynamic_cast(imported.obj("frac_const")), nullptr); + EXPECT_EQ(imported.var("frac_const"), nullptr); +} + TEST(RooFitHS3, RooBernstein) { int status = validate({"RooBernstein::bernstein(x[0, 10], { a[1], 3, b[5, 0, 20] })"}); @@ -506,6 +762,12 @@ TEST(RooFitHS3, RooPolynomial) EXPECT_EQ(status, 0); status = validate({"Polynomial::poly1(x[0, 10], {a_2[0.003, -10, 10]}, 2)"}); EXPECT_EQ(status, 0); + + RooWorkspace ws; + ws.factory("Polynomial::poly1(x[0, 10], {a_2[0.003, -10, 10]}, 2)"); + const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString(); + EXPECT_NE(json.find("\"coefficients\":[1.0,0.0,\"a_2\"]"), std::string::npos) << json; + EXPECT_EQ(json.find("\"coefficients\":[\"1.0\""), std::string::npos) << json; } TEST(RooFitHS3, RooPowerSum) @@ -824,6 +1086,7 @@ TEST(RooFitHS3, UnbinnedDatasetAxisRange) const std::string axesNode = json1.substr(axesPos, axesEnd - axesPos); EXPECT_NE(axesNode.find("\"min\":-2.5"), std::string::npos) << axesNode; EXPECT_NE(axesNode.find("\"max\":7.5"), std::string::npos) << axesNode; + EXPECT_EQ(axesNode.find("\"value\""), std::string::npos) << axesNode; } // HistFactory channels with samples that have a zero-yield bin together with a diff --git a/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h b/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h index 245673fcf577f..ff0c1e695f430 100644 --- a/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h +++ b/roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h @@ -100,8 +100,10 @@ class JSONNode { virtual bool is_container() const = 0; virtual bool is_map() const = 0; virtual bool is_seq() const = 0; + virtual bool is_null() const = 0; virtual JSONNode &set_map() = 0; virtual JSONNode &set_seq() = 0; + virtual JSONNode &set_null() = 0; virtual void clear() = 0; virtual std::string key() const = 0; diff --git a/roofit/jsoninterface/src/JSONParser.cxx b/roofit/jsoninterface/src/JSONParser.cxx index e93b32b7cae76..d1d470b267598 100644 --- a/roofit/jsoninterface/src/JSONParser.cxx +++ b/roofit/jsoninterface/src/JSONParser.cxx @@ -183,6 +183,11 @@ bool TJSONTree::Node::is_seq() const return node->get().is_array(); } +bool TJSONTree::Node::is_null() const +{ + return node->get().is_null(); +} + namespace { // To check whether it's allowed to reset the type of an object. We allow @@ -231,6 +236,12 @@ TJSONTree::Node &TJSONTree::Node::set_seq() return *this; } +TJSONTree::Node &TJSONTree::Node::set_null() +{ + node->get() = nullptr; + return *this; +} + void TJSONTree::Node::clear() { node->get().clear(); diff --git a/roofit/jsoninterface/src/JSONParser.h b/roofit/jsoninterface/src/JSONParser.h index d05625df6657e..241bee4d796b2 100644 --- a/roofit/jsoninterface/src/JSONParser.h +++ b/roofit/jsoninterface/src/JSONParser.h @@ -53,8 +53,10 @@ class TJSONTree : public RooFit::Detail::JSONTree { bool is_container() const override; bool is_map() const override; bool is_seq() const override; + bool is_null() const override; Node &set_map() override; Node &set_seq() override; + Node &set_null() override; void clear() override; std::string key() const override; std::string val() const override; diff --git a/roofit/jsoninterface/src/RYMLParser.cxx b/roofit/jsoninterface/src/RYMLParser.cxx index b04d0cc5978db..ea15f589fb8b5 100644 --- a/roofit/jsoninterface/src/RYMLParser.cxx +++ b/roofit/jsoninterface/src/RYMLParser.cxx @@ -123,6 +123,11 @@ TRYMLTree::Node &TRYMLTree::Node::set_seq() return *this; } +TRYMLTree::Node &TRYMLTree::Node::set_null() +{ + throw std::logic_error("Function not yet implemented"); +} + void TRYMLTree::Node::clear() { throw std::logic_error("Function not yet implemented"); @@ -227,6 +232,11 @@ bool TRYMLTree::Node::is_seq() const return node->get().is_seq(); } +bool TRYMLTree::Node::is_null() const +{ + return !node->get().has_val() && node->get().num_children() == 0; +} + std::string TRYMLTree::Node::key() const { // obtain the key of this node diff --git a/roofit/jsoninterface/src/RYMLParser.h b/roofit/jsoninterface/src/RYMLParser.h index 673d196f0ef2e..7be8bda5c4291 100644 --- a/roofit/jsoninterface/src/RYMLParser.h +++ b/roofit/jsoninterface/src/RYMLParser.h @@ -50,8 +50,10 @@ class TRYMLTree : public RooFit::Detail::JSONTree { bool is_container() const override; bool is_map() const override; bool is_seq() const override; + bool is_null() const override; Node &set_map() override; Node &set_seq() override; + Node &set_null() override; void clear() override; std::string key() const override; std::string val() const override;