Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class RooJSONFactoryWSTool {
static constexpr bool useListsInsteadOfDicts = true;
static bool allowExportInvalidNames;
static bool allowSanitizeNames;
static bool importNoDomainParametersAsRooConstVars;
static RooWorkspace sanitizeWS(const RooWorkspace &ws);
static RooWorkspace cleanWS(const RooWorkspace &ws, bool onlyModelConfig = false);

Expand Down
178 changes: 154 additions & 24 deletions roofit/hs3/src/Domains.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include "Domains.h"

#include <RooAbsBinning.h>
#include <RooBinning.h>
#include <RooFitHS3/RooJSONFactoryWSTool.h>
#include <RooNumber.h>
#include <RooRealVar.h>
Expand All @@ -25,6 +27,28 @@ namespace Detail {

constexpr static auto defaultDomainName = "default_domain";

namespace {

double readBound(RooFit::Detail::JSONNode const &node, const char *key, double defaultValue)
{
if (!node.has_child(key)) {
return defaultValue;
}
auto const &bound = node[key];
return bound.is_null() ? defaultValue : bound.val_double();
}

void writeBound(RooFit::Detail::JSONNode &node, double value)
{
if (RooNumber::isInfinite(value)) {
node.set_null();
} else {
node << value;
}
}

} // namespace

void Domains::populate(RooWorkspace &ws) const
{
auto default_domain = _map.find(defaultDomainName);
Expand All @@ -49,12 +73,12 @@ void Domains::readVariable(const char *name, double min, double max, const char

void Domains::readVariable(RooRealVar const &var)
{
readVariable(var.GetName(), var.getMin(), var.getMax(), defaultDomainName);
_map[defaultDomainName].readVariable(var.GetName(), var.getBinning());
for (const auto &bname : var.getBinningNames()) {
if (bname.empty())
continue;
auto &binning = var.getBinning(bname.c_str());
readVariable(var.GetName(), binning.lowBound(), binning.highBound(), bname.c_str());
_map[bname].readVariable(var.GetName(), binning);
}
}

Expand Down Expand Up @@ -91,38 +115,119 @@ void Domains::writeJSON(RooFit::Detail::JSONNode &node) const
}
}

bool Domains::hasVariable(const char *name) const
{
for (auto const &domain : _map) {
if (domain.second.hasVariable(name)) {
return true;
}
}
return false;
}

void Domains::ProductDomain::readVariable(RooRealVar const &var)
{
readVariable(var.GetName(), var.getMin(), var.getMax());
readVariable(var.GetName(), var.getBinning());
}

void Domains::ProductDomain::readVariable(const char *name, double min, double max)
void Domains::ProductDomain::readBinning(ProductDomainElement &elem, RooAbsBinning const &binning)
{
if (RooNumber::isInfinite(min) && RooNumber::isInfinite(max))
elem.hasNBins = false;
elem.nBins = 0;
elem.edges.clear();

const int nBins = binning.numBins();
if (nBins <= 0) {
return;
}

if (binning.isUniform()) {
elem.hasNBins = true;
elem.nBins = nBins;
} else {
elem.edges.push_back(binning.binLow(0));
for (int i = 0; i < nBins; ++i) {
elem.edges.push_back(binning.binHigh(i));
}
}
}

void Domains::ProductDomain::readVariable(const char *name, RooAbsBinning const &binning)
{
auto &elem = _map[name];

if (!RooNumber::isInfinite(min)) {
elem.hasMin = true;
elem.min = min;
elem.hasMin = true;
elem.min = binning.lowBound();
elem.hasMax = true;
elem.max = binning.highBound();
readBinning(elem, binning);
}

void Domains::ProductDomain::readVariable(const char *name, double min, double max)
{
auto &elem = _map[name];

elem.hasMin = true;
elem.min = min;
elem.hasMax = true;
elem.max = max;
elem.hasNBins = false;
elem.nBins = 0;
elem.edges.clear();
}

void Domains::ProductDomain::applyBinning(RooRealVar &var, ProductDomainElement const &elem, const char *name)
{
if (!elem.edges.empty()) {
RooBinning binning(elem.edges.front(), elem.edges.back());
for (double edge : elem.edges) {
binning.addBoundary(edge);
}
var.setBinning(binning, name);
} else if (elem.hasNBins && elem.nBins != 0) {
var.setBins(elem.nBins, name);
}
if (!RooNumber::isInfinite(max)) {
elem.hasMax = true;
elem.max = max;
}

void Domains::ProductDomain::writeBinning(RooFit::Detail::JSONNode &node, ProductDomainElement const &elem)
{
if (!elem.edges.empty()) {
auto &edges = node["edges"].set_seq();
for (double edge : elem.edges) {
edges.append_child() << edge;
}
} else if (elem.hasNBins && elem.nBins != 0) {
node["nbins"] << elem.nBins;
}
}
void Domains::ProductDomain::writeVariable(RooRealVar &var) const
{
auto found = _map.find(var.GetName());
if (found != _map.end()) {
auto const &elem = found->second;
if (elem.hasMin)
var.setMin(elem.min);
if (elem.hasMax)
var.setMax(elem.max);
if (elem.hasMin) {
if (RooNumber::isInfinite(elem.min)) {
var.removeMin();
} else {
var.setMin(elem.min);
}
}
if (elem.hasMax) {
if (RooNumber::isInfinite(elem.max)) {
var.removeMax();
} else {
var.setMax(elem.max);
}
}
applyBinning(var, elem);
}
}

bool Domains::ProductDomain::hasVariable(const char *name) const
{
return _map.find(name) != _map.end();
}

void Domains::ProductDomain::readJSON(RooFit::Detail::JSONNode const &node)
{
if (!node.has_child("type") || node["type"].val() != "product_domain") {
Expand All @@ -132,13 +237,34 @@ void Domains::ProductDomain::readJSON(RooFit::Detail::JSONNode const &node)
auto &elem = _map[RooJSONFactoryWSTool::name(varNode)];

if (varNode.has_child("min")) {
elem.min = varNode["min"].val_double();
elem.min = readBound(varNode, "min", -RooNumber::infinity());
elem.hasMin = true;
}
if (varNode.has_child("max")) {
elem.max = varNode["max"].val_double();
elem.max = readBound(varNode, "max", RooNumber::infinity());
elem.hasMax = true;
}
if (varNode.has_child("edges")) {
elem.hasNBins = false;
elem.edges.clear();
for (auto const &edge : varNode["edges"].children()) {
elem.edges.push_back(edge.val_double());
}
if (!elem.edges.empty()) {
if (!elem.hasMin) {
elem.min = elem.edges.front();
elem.hasMin = true;
}
if (!elem.hasMax) {
elem.max = elem.edges.back();
elem.hasMax = true;
}
}
} else if (varNode.has_child("nbins")) {
elem.nBins = varNode["nbins"].val_int();
elem.hasNBins = elem.nBins != 0;
elem.edges.clear();
}
}
}
void Domains::ProductDomain::writeJSON(RooFit::Detail::JSONNode &node) const
Expand All @@ -152,21 +278,22 @@ void Domains::ProductDomain::writeJSON(RooFit::Detail::JSONNode &node) const
for (auto const &item : _map) {
auto const &elem = item.second;
RooFit::Detail::JSONNode &varnode = RooJSONFactoryWSTool::appendNamedChild(variablesNode, item.first);
if (elem.hasMin)
varnode["min"] << elem.min;
if (elem.hasMax)
varnode["max"] << elem.max;
writeBound(varnode["min"], elem.hasMin ? elem.min : -RooNumber::infinity());
writeBound(varnode["max"], elem.hasMax ? elem.max : RooNumber::infinity());
writeBinning(varnode, elem);
}
}
void Domains::ProductDomain::populate(RooWorkspace &ws) const
{
for (auto const &item : _map) {
const auto &name = item.first;
if (!ws.var(name)) {
if (!ws.arg(name)) {
const auto &elem = item.second;
const double vMin = elem.hasMin ? elem.min : -RooNumber::infinity();
const double vMax = elem.hasMax ? elem.max : RooNumber::infinity();
ws.import(RooRealVar{name.c_str(), name.c_str(), vMin, vMax});
RooRealVar var{name.c_str(), name.c_str(), vMin, vMax};
applyBinning(var, elem);
ws.import(var);
}
}
}
Expand All @@ -176,7 +303,10 @@ void Domains::ProductDomain::registerBinnings(const char *name, RooWorkspace &ws
auto *var = ws.var(item.first);
if (!var)
continue;
var->setRange(name, item.second.min, item.second.max);
const double vMin = item.second.hasMin ? item.second.min : -RooNumber::infinity();
const double vMax = item.second.hasMax ? item.second.max : RooNumber::infinity();
var->setRange(name, vMin, vMax);
applyBinning(*var, item.second, name);
}
}

Expand Down
13 changes: 13 additions & 0 deletions roofit/hs3/src/Domains.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <map>
#include <vector>

class RooAbsBinning;
class RooRealVar;
class RooWorkspace;

Expand All @@ -42,17 +43,22 @@ class Domains {
void readJSON(RooFit::Detail::JSONNode const &);
void writeJSON(RooFit::Detail::JSONNode &) const;

bool hasVariable(const char *name) const;

void populate(RooWorkspace &ws) const;

class ProductDomain {
public:
void readVariable(const RooRealVar &);
void readVariable(const char *name, RooAbsBinning const &binning);
void readVariable(const char *name, double min, double max);
void writeVariable(RooRealVar &) const;

void readJSON(RooFit::Detail::JSONNode const &);
void writeJSON(RooFit::Detail::JSONNode &) const;

bool hasVariable(const char *name) const;

void populate(RooWorkspace &ws) const;
void registerBinnings(const char *name, RooWorkspace &ws) const;

Expand All @@ -62,8 +68,15 @@ class Domains {
bool hasMax = false;
double min = 0.0;
double max = 0.0;
bool hasNBins = false;
int nBins = 0;
std::vector<double> edges;
};

static void applyBinning(RooRealVar &var, ProductDomainElement const &elem, const char *name = nullptr);
static void readBinning(ProductDomainElement &elem, RooAbsBinning const &binning);
static void writeBinning(RooFit::Detail::JSONNode &node, ProductDomainElement const &elem);

std::map<std::string, ProductDomainElement> _map;
};

Expand Down
35 changes: 33 additions & 2 deletions roofit/hs3/src/JSONFactories_HistFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,36 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa
"'");
}
}
double poissonTau(RooPoisson const &constraint, RooAbsArg const &gamma)
{
auto const *mean = dynamic_cast<RooProduct const *>(&constraint.getMean());
if (!mean) {
RooJSONFactoryWSTool::error("Poisson gamma constraint mean is not a RooProduct: " +
std::string(constraint.GetName()));
}

for (RooAbsArg *arg : mean->servers()) {
if (arg == &gamma) {
continue;
}

if (auto const *tau = dynamic_cast<RooConstVar const *>(arg)) {
return tau->getVal();
}

// Imported workspaces can sometimes represent
// constants as constant RooRealVars.
if (auto const *real = dynamic_cast<RooAbsReal const *>(arg)) {
if (real->isConstant() || endsWith(std::string(real->GetName()), "_tau")) {
return real->getVal();
}
}
}

RooJSONFactoryWSTool::error("Could not find tau component in Poisson gamma constraint mean: " +
std::string(constraint.GetName()));
return std::numeric_limits<double>::quiet_NaN();
}

bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet const &varlist,
RooAbsArg const *mcStatObject, const std::string &fprefix, const JSONNode &p,
Expand Down Expand Up @@ -334,6 +364,7 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con
// this is dealt with at a different place, ignore it for now
} else if (modtype == "normfactor") {
RooRealVar &constrParam = getOrCreate<RooRealVar>(ws, sysname, 1., -3, 5);
constrParam.setError(0.0);
normElems.add(constrParam);
if (mod.has_child("constraint_name") || mod.has_child("constraint_type")) {
// for norm factors, constraints are optional
Expand Down Expand Up @@ -1054,7 +1085,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
if (constraint) {
sample.barlowBeestonLightConstraintType = constraint->IsA();
if (RooPoisson *constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
double erel = 1. / std::sqrt(constraint_p->getX().getVal());
double erel = 1. / std::sqrt(poissonTau(*constraint_p, *g));
channel.rel_errors[idx] = erel;
} else if (RooGaussian *constraint_g = dynamic_cast<RooGaussian *>(constraint)) {
double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
Expand Down Expand Up @@ -1088,7 +1119,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
if (!constraint) {
sys.constraints.push_back(0.0);
} else if (auto constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
sys.constraints.push_back(1. / std::sqrt(constraint_p->getX().getVal()));
sys.constraints.push_back(1. / std::sqrt(poissonTau(*constraint_p, *g)));
if (!sys.constraint) {
sys.constraintType = RooPoisson::Class();
}
Expand Down
Loading
Loading