Add stitching (#33)
* added dy_process_ids producer and dy datasets

* forgot the linting

* fixed njet variable

* linting

* Continue generalizing dy process ids. (#4)

* Continue generalizing dy process ids.

* Simplification.

* Typo.

* Typo.

* define stitching ranges with aux dict

* typo

* Full draft for dy identifier.

* Simplify njets lookup, improve typing.

* Rename variables.

* update cf

* Update cf.

* Use stitched producer.

* Update cf.

* Update cf.

---------

Co-authored-by: haddadanas <hhhaaanas@gmail.com>
Co-authored-by: Anas <103462379+haddadanas@users.noreply.github.com>
3 people authored Jul 18, 2024
1 parent 5fd5a37 commit 031150d
Showing 6 changed files with 228 additions and 7 deletions.
35 changes: 35 additions & 0 deletions hbt/config/configs_hbt.py
@@ -168,6 +168,7 @@ def if_era(
# "ttz_llnunu_amcatnlo", not available
# "ttw_nlu_amcatnlo", not available
# "ttw_qq_amcatnlo", not available
"ttz_zqq_amcatnlo",
"ttzz_madgraph",
# "ttwz_madgraph", not available
"ttww_madgraph",
@@ -184,6 +185,19 @@
"dy_m4to10_amcatnlo",
"dy_m10to50_amcatnlo",
"dy_m50toinf_amcatnlo",
"dy_m50toinf_0j_amcatnlo",
"dy_m50toinf_1j_amcatnlo",
"dy_m50toinf_2j_amcatnlo",
"dy_m50toinf_1j_pt40to100_amcatnlo",
"dy_m50toinf_1j_pt100to200_amcatnlo",
"dy_m50toinf_1j_pt200to400_amcatnlo",
"dy_m50toinf_1j_pt400to600_amcatnlo",
"dy_m50toinf_1j_pt600toinf_amcatnlo",
"dy_m50toinf_2j_pt40to100_amcatnlo",
"dy_m50toinf_2j_pt100to200_amcatnlo",
"dy_m50toinf_2j_pt200to400_amcatnlo",
"dy_m50toinf_2j_pt400to600_amcatnlo",
"dy_m50toinf_2j_pt600toinf_amcatnlo",
"w_lnu_amcatnlo",
# "ewk_wm_lnu_m50toinf_madgraph", not available
# "ewk_w_lnu_m50toinf_madgraph", not available
@@ -204,6 +218,8 @@
"wmh_wlnu_hbb_powheg",
"wph_wlnu_hbb_powheg",
"zh_gg_zll_hbb_powheg",
"zh_gg_znunu_hbb_powheg",
"zh_gg_zqq_hbb_powheg",
# "wph_tautau_powheg", not available
# "wmh_tautau_powheg", not available
# "tth_tautau_powheg", not available
@@ -225,6 +241,8 @@
        dataset.add_tag(("has_top", "is_ttbar"))
    elif dataset.name.startswith("st"):
        dataset.add_tag(("has_top", "is_single_top"))
    if dataset.name.startswith("dy"):
        dataset.add_tag("is_dy")
    if re.match(r"^(ww|wz|zz)_.*pythia$", dataset.name):
        dataset.add_tag("no_lhe_weights")

@@ -270,6 +288,23 @@ def if_era(
"sm_data": ["data"] + sm_group,
}

# define inclusive datasets for the dy process identification with corresponding leaf processes
cfg.x.dy_stitching = {
    "m50toinf": {
        "inclusive_dataset": cfg.datasets.n.dy_m50toinf_amcatnlo,
        "leaf_processes": [
            # the following processes cover the full njet and pt phase space
            procs.n.dy_m50toinf_0j,
            *(
                procs.get(f"dy_m50toinf_{nj}j_pt{pt}")
                for nj in [1, 2]
                for pt in ["0to40", "40to100", "100to200", "200to400", "400to600", "600toinf"]
            ),
            procs.n.dy_m50toinf_ge3j,
        ],
    },
}

# dataset groups for conveniently looping over certain datasets
# (used in wrapper_factory and during plotting)
cfg.x.dataset_groups = {}
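As a quick illustration of how the new cfg.x.dy_stitching entry partitions the phase space, the following minimal sketch (assuming only the structure defined above: an "inclusive_dataset" plus "leaf_processes" carrying an njets aux range and, for pt-sliced processes, a ptll range) lists the slices per stitching group:

# sketch: list the phase space slices covered by each stitching group
for name, dy_cfg in cfg.x.dy_stitching.items():
    print(f"group '{name}', inclusive dataset: {dy_cfg['inclusive_dataset'].name}")
    for proc in dy_cfg["leaf_processes"]:
        # njets is set for every leaf process, ptll only for the pt-sliced ones
        print(f"  {proc.name}: njets={proc.x('njets', None)}, ptll={proc.x('ptll', None)}")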
12 changes: 7 additions & 5 deletions hbt/production/default.py
@@ -5,14 +5,16 @@
"""

from columnflow.production import Producer, producer
from columnflow.production.normalization import normalization_weights
from columnflow.production.normalization import stitched_normalization_weights
from columnflow.production.categories import category_ids
from columnflow.production.cms.electron import electron_weights
from columnflow.production.cms.muon import muon_weights
from columnflow.util import maybe_import

from hbt.production.features import features
from hbt.production.weights import normalized_pu_weight, normalized_pdf_weight, normalized_murmuf_weight
from hbt.production.weights import (
    normalized_pu_weight, normalized_pdf_weight, normalized_murmuf_weight,
)
from hbt.production.btag import normalized_btag_weights
from hbt.production.tau import tau_weights, trigger_weights
from hbt.util import IF_DATASET_HAS_LHE_WEIGHTS
@@ -23,12 +25,12 @@

@producer(
    uses={
        category_ids, features, normalization_weights, normalized_pu_weight,
        category_ids, features, stitched_normalization_weights, normalized_pu_weight,
        normalized_btag_weights, tau_weights, electron_weights, muon_weights, trigger_weights,
        IF_DATASET_HAS_LHE_WEIGHTS(normalized_pdf_weight, normalized_murmuf_weight),
    },
    produces={
        category_ids, features, normalization_weights, normalized_pu_weight,
        category_ids, features, stitched_normalization_weights, normalized_pu_weight,
        normalized_btag_weights, tau_weights, electron_weights, muon_weights, trigger_weights,
        IF_DATASET_HAS_LHE_WEIGHTS(normalized_pdf_weight, normalized_murmuf_weight),
    },
@@ -43,7 +45,7 @@ def default(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
    # mc-only weights
    if self.dataset_inst.is_mc:
        # normalization weights
        events = self[normalization_weights](events, **kwargs)
        events = self[stitched_normalization_weights](events, **kwargs)

        # normalized pdf weight
        if self.has_dep(normalized_pdf_weight):
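Conceptually, the switch from normalization_weights to stitched_normalization_weights means the per-event weight is looked up via the per-event process_id column rather than a single per-dataset cross section, so overlapping inclusive and sliced datasets can be normalized per leaf process. A simplified sketch of that idea (not the columnflow implementation; lumi, xsecs and sum_weights are hypothetical inputs):

import numpy as np

def stitched_norm_weight(process_ids, lumi, xsecs, sum_weights):
    # xsecs and sum_weights: hypothetical dicts mapping a leaf process id to its
    # cross section and to the sum of mc weights of all events carrying that id
    weights = np.empty(len(process_ids), dtype=np.float64)
    for pid in np.unique(process_ids):
        weights[process_ids == pid] = lumi * xsecs[pid] / sum_weights[pid]
    return weights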
142 changes: 142 additions & 0 deletions hbt/production/processes.py
@@ -0,0 +1,142 @@
# coding: utf-8

"""
Process ID producer relevant for the stitching of the DY samples.
"""

import functools

import law

from columnflow.production import Producer, producer
from columnflow.util import maybe_import, InsertableDict
from columnflow.columnar_util import set_ak_column

from hbt.util import IF_DATASET_IS_DY

np = maybe_import("numpy")
ak = maybe_import("awkward")
sp = maybe_import("scipy")
maybe_import("scipy.sparse")


logger = law.logger.get_logger(__name__)

NJetsRange = tuple[int, int]
PtRange = tuple[float, float]

set_ak_column_i64 = functools.partial(set_ak_column, value_type=np.int64)


@producer(
    uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt")},
    produces={IF_DATASET_IS_DY("process_id")},
)
def process_ids_dy(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
    """
    Assigns each dy event a single process id, based on the number of jets and the di-lepton pt of
    the LHE record. This is used for the stitching of the DY samples.
    """
    # as always, we assume that each dataset has exactly one process associated with it
    if len(self.dataset_inst.processes) != 1:
        raise NotImplementedError(
            f"dataset {self.dataset_inst.name} has {len(self.dataset_inst.processes)} processes "
            "assigned, which is not yet implemented",
        )
    process_inst = self.dataset_inst.processes.get_first()

    # get the number of nlo jets and the di-lepton pt
    njets = events.LHE.NpNLO
    pt = events.LHE.Vpt

    # raise a warning if a dataset was created for a specific "bin" (leaf process)
    # but contains events that do not actually fit into it
    njets_range = process_inst.x("njets", None)
    if njets_range is not None:
        outliers = (njets < njets_range[0]) | (njets >= njets_range[1])
        if ak.any(outliers):
            logger.warning(
                f"dataset {self.dataset_inst.name} is meant to contain njet values in the range "
                f"[{njets_range[0]}, {njets_range[1]}), but found {ak.sum(outliers)} events "
                "outside this range",
            )
    pt_range = process_inst.x("ptll", None)
    if pt_range is not None:
        outliers = (pt < pt_range[0]) | (pt >= pt_range[1])
        if ak.any(outliers):
            logger.warning(
                f"dataset {self.dataset_inst.name} is meant to contain ptll values in the range "
                f"[{pt_range[0]}, {pt_range[1]}), but found {ak.sum(outliers)} events outside this "
                "range",
            )

    # lookup the id and check for invalid values
    process_ids = np.squeeze(np.asarray(self.id_table[self.key_func(njets, pt)].todense()))
    invalid_mask = process_ids == 0
    if ak.any(invalid_mask):
        raise ValueError(
            f"found {ak.sum(invalid_mask)} dy events that could not be assigned to a process",
        )

    # store them
    events = set_ak_column_i64(events, "process_id", process_ids)

    return events


@process_ids_dy.setup
def process_ids_dy_setup(
    self: Producer,
    reqs: dict,
    inputs: dict,
    reader_targets: InsertableDict,
) -> None:
    # define stitching ranges for the DY datasets covered by this producer's dy_inclusive_dataset
    stitching_ranges: dict[NJetsRange, list[PtRange]] = {}
    for proc in self.dy_leaf_processes:
        njets = proc.x.njets
        stitching_ranges.setdefault(njets, [])
        if proc.has_aux("ptll"):
            stitching_ranges[njets].append(proc.x.ptll)

    # sort by the first element of the ptll range
    sorted_stitching_ranges: list[tuple[NJetsRange, list[PtRange]]] = [
        (nj_range, sorted(stitching_ranges[nj_range], key=lambda ptll_range: ptll_range[0]))
        for nj_range in sorted(stitching_ranges.keys(), key=lambda nj_range: nj_range[0])
    ]

    # define a key function that maps njets and pt to a unique key for use in a lookup table
    def key_func(njets, pt):
        # potentially convert single values into arrays
        single = False
        if isinstance(njets, int):
            assert isinstance(pt, (int, float))
            njets = np.array([njets], dtype=np.int32)
            pt = np.array([pt], dtype=np.float32)
            single = True

        # map into bins (index 0 means no binning)
        nj_bins = np.zeros(len(njets), dtype=np.int32)
        pt_bins = np.zeros(len(pt), dtype=np.int32)
        for nj_bin, (nj_range, pt_ranges) in enumerate(sorted_stitching_ranges, 1):
            # nj_bin
            nj_mask = (nj_range[0] <= njets) & (njets < nj_range[1])
            nj_bins[nj_mask] = nj_bin
            # pt_bin
            for pt_bin, (pt_min, pt_max) in enumerate(pt_ranges, 1):
                pt_mask = (pt_min <= pt) & (pt < pt_max)
                pt_bins[nj_mask & pt_mask] = pt_bin

        return (nj_bins[0], pt_bins[0]) if single else (nj_bins, pt_bins)

    self.key_func = key_func

    # define the lookup table
    max_nj_bin = len(sorted_stitching_ranges)
    max_pt_bin = max(map(len, stitching_ranges.values()))
    self.id_table = sp.sparse.lil_matrix((max_nj_bin + 1, max_pt_bin + 1), dtype=np.int64)

    # fill it
    for proc in self.dy_leaf_processes:
        key = key_func(proc.x.njets[0], proc.x("ptll", [-1])[0])
        self.id_table[key] = proc.id
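To make the binning above concrete, here is a self-contained toy version of key_func with a single hypothetical njets slice [1, 2) carrying two pt slices; events outside all ranges stay in bin 0 and would later map to process id 0, i.e. be reported as invalid:

import numpy as np

# hypothetical ranges: njets in [1, 2) with pt slices [0, 100) and [100, inf)
sorted_stitching_ranges = [((1, 2), [(0.0, 100.0), (100.0, np.inf)])]

def key_func(njets, pt):
    # bin index 0 means "no matching range"
    nj_bins = np.zeros(len(njets), dtype=np.int32)
    pt_bins = np.zeros(len(pt), dtype=np.int32)
    for nj_bin, (nj_range, pt_ranges) in enumerate(sorted_stitching_ranges, 1):
        nj_mask = (nj_range[0] <= njets) & (njets < nj_range[1])
        nj_bins[nj_mask] = nj_bin
        for pt_bin, (pt_min, pt_max) in enumerate(pt_ranges, 1):
            pt_mask = (pt_min <= pt) & (pt < pt_max)
            pt_bins[nj_mask & pt_mask] = pt_bin
    return nj_bins, pt_bins

njets = np.array([0, 1, 1, 3])
pt = np.array([50.0, 50.0, 250.0, 10.0])
nj_bins, pt_bins = key_func(njets, pt)
# nj_bins -> [0, 1, 1, 0], pt_bins -> [0, 1, 2, 0]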
33 changes: 32 additions & 1 deletion hbt/selection/default.py
@@ -4,6 +4,8 @@
Selection methods.
"""

from __future__ import annotations

from operator import and_
from functools import reduce
from collections import defaultdict
@@ -25,6 +27,7 @@
from hbt.selection.lepton import lepton_selection
from hbt.selection.jet import jet_selection
from hbt.production.features import cutflow_features
from hbt.production.processes import process_ids_dy
from hbt.util import IF_DATASET_HAS_LHE_WEIGHTS

np = maybe_import("numpy")
@@ -114,7 +117,10 @@ def default(
    )

    # create process ids
    events = self[process_ids](events, **kwargs)
    if self.process_ids_dy is not None:
        events = self[self.process_ids_dy](events, **kwargs)
    else:
        events = self[process_ids](events, **kwargs)

    # some cutflow features
    events = self[cutflow_features](events, results.objects, **kwargs)
@@ -177,3 +183,28 @@
    )

    return events, results


@default.init
def default_init(self: Selector) -> None:
    if getattr(self, "dataset_inst", None) is None:
        return

    self.process_ids_dy: process_ids_dy | None = None
    if self.dataset_inst.has_tag("is_dy"):
        # check if this dataset is covered by any dy id producer
        for name, dy_cfg in self.config_inst.x.dy_stitching.items():
            dataset_inst = dy_cfg["inclusive_dataset"]
            # the dataset is "covered" if its process is a subprocess of that of the dy dataset
            if dataset_inst.has_process(self.dataset_inst.processes.get_first()):
                self.process_ids_dy = process_ids_dy.derive(f"process_ids_dy_{name}", cls_dict={
                    "dy_inclusive_dataset": dataset_inst,
                    "dy_leaf_processes": dy_cfg["leaf_processes"],
                })

                # add it as a dependency
                self.uses.add(self.process_ids_dy)
                self.produces.add(self.process_ids_dy)

                # stop after the first match
                break
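The init hook above uses derive to build a specialized producer subclass per stitching group at runtime. A sketch of what the call effectively creates for the "m50toinf" group defined in configs_hbt.py (the variable name on the left is hypothetical; derived classes are registered by name):

# sketch: the derived producer for the "m50toinf" stitching group
process_ids_dy_m50toinf = process_ids_dy.derive("process_ids_dy_m50toinf", cls_dict={
    "dy_inclusive_dataset": cfg.datasets.n.dy_m50toinf_amcatnlo,
    "dy_leaf_processes": cfg.x.dy_stitching["m50toinf"]["leaf_processes"],
})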
11 changes: 11 additions & 0 deletions hbt/util.py
@@ -46,3 +46,14 @@ def IF_DATASET_HAS_LHE_WEIGHTS(
        return self.get()

    return None if func.dataset_inst.has_tag("no_lhe_weights") else self.get()


@deferred_column
def IF_DATASET_IS_DY(
    self: ArrayFunction.DeferredColumn,
    func: ArrayFunction,
) -> Any | set[Any]:
    if getattr(func, "dataset_inst", None) is None:
        return self.get()

    return self.get() if func.dataset_inst.has_tag("is_dy") else None
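Like IF_DATASET_HAS_LHE_WEIGHTS above, this deferred column resolves at runtime, so the wrapped columns are only requested for datasets that actually carry the tag. Its use in hbt/production/processes.py above is the reference example:

# columns are only read, and process_id only produced, for datasets tagged "is_dy"
@producer(
    uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt")},
    produces={IF_DATASET_IS_DY("process_id")},
)
def process_ids_dy(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
    ...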
