From 031150d52453bad9cbdcfaaff4f24a99c86bd544 Mon Sep 17 00:00:00 2001
From: Marcel Rieger
Date: Thu, 18 Jul 2024 19:45:14 +0200
Subject: [PATCH] Add stitching (#33)

* added dy_process_ids producer and dy datasets
* forgot the linting
* fixed njet variable
* linting
* Continue generalizing dy process ids. (#4)
* Continue generalizing dy process ids.
* Simplification.
* Typo.
* Typo.
* define stitching ranges with aux dict
* typo
* Full draft for dy identifier.
* Simplify njets lookup, improve typing.
* Rename variables.
* update cf
* Update cf.
* Use stitched producer.
* Update cf.
* Update cf.

---------

Co-authored-by: haddadanas
Co-authored-by: Anas <103462379+haddadanas@users.noreply.github.com>
---
 hbt/config/configs_hbt.py   |  35 +++++++++
 hbt/production/default.py   |  12 +--
 hbt/production/processes.py | 142 ++++++++++++++++++++++++++++++++++++
 hbt/selection/default.py    |  33 ++++++++-
 hbt/util.py                 |  11 +++
 modules/columnflow          |   2 +-
 6 files changed, 228 insertions(+), 7 deletions(-)
 create mode 100644 hbt/production/processes.py

diff --git a/hbt/config/configs_hbt.py b/hbt/config/configs_hbt.py
index d57a010..f49583d 100644
--- a/hbt/config/configs_hbt.py
+++ b/hbt/config/configs_hbt.py
@@ -168,6 +168,7 @@ def if_era(
         # "ttz_llnunu_amcatnlo", not available
         # "ttw_nlu_amcatnlo", not available
         # "ttw_qq_amcatnlo", not available
+        "ttz_zqq_amcatnlo",
         "ttzz_madgraph",
         # "ttwz_madgraph", not available
         "ttww_madgraph",
@@ -184,6 +185,19 @@ def if_era(
         "dy_m4to10_amcatnlo",
         "dy_m10to50_amcatnlo",
         "dy_m50toinf_amcatnlo",
+        "dy_m50toinf_0j_amcatnlo",
+        "dy_m50toinf_1j_amcatnlo",
+        "dy_m50toinf_2j_amcatnlo",
+        "dy_m50toinf_1j_pt40to100_amcatnlo",
+        "dy_m50toinf_1j_pt100to200_amcatnlo",
+        "dy_m50toinf_1j_pt200to400_amcatnlo",
+        "dy_m50toinf_1j_pt400to600_amcatnlo",
+        "dy_m50toinf_1j_pt600toinf_amcatnlo",
+        "dy_m50toinf_2j_pt40to100_amcatnlo",
+        "dy_m50toinf_2j_pt100to200_amcatnlo",
+        "dy_m50toinf_2j_pt200to400_amcatnlo",
+        "dy_m50toinf_2j_pt400to600_amcatnlo",
+        "dy_m50toinf_2j_pt600toinf_amcatnlo",
         "w_lnu_amcatnlo",
         # "ewk_wm_lnu_m50toinf_madgraph", not available
         # "ewk_w_lnu_m50toinf_madgraph", not available
@@ -204,6 +218,8 @@ def if_era(
         "wmh_wlnu_hbb_powheg",
         "wph_wlnu_hbb_powheg",
         "zh_gg_zll_hbb_powheg",
+        "zh_gg_znunu_hbb_powheg",
+        "zh_gg_zqq_hbb_powheg",
         # "wph_tautau_powheg", not available
         # "wmh_tautau_powheg", not available
         # "tth_tautau_powheg", not available
@@ -225,6 +241,8 @@ def if_era(
             dataset.add_tag(("has_top", "is_ttbar"))
         elif dataset.name.startswith("st"):
             dataset.add_tag(("has_top", "is_single_top"))
+        if dataset.name.startswith("dy"):
+            dataset.add_tag("is_dy")
         if re.match(r"^(ww|wz|zz)_.*pythia$", dataset.name):
             dataset.add_tag("no_lhe_weights")
 
@@ -270,6 +288,23 @@ def if_era(
         "sm_data": ["data"] + sm_group,
     }
 
+    # define inclusive datasets for the dy process identification with corresponding leaf processes
+    cfg.x.dy_stitching = {
+        "m50toinf": {
+            "inclusive_dataset": cfg.datasets.n.dy_m50toinf_amcatnlo,
+            "leaf_processes": [
+                # the following processes cover the full njet and pt phase space
+                procs.n.dy_m50toinf_0j,
+                *(
+                    procs.get(f"dy_m50toinf_{nj}j_pt{pt}")
+                    for nj in [1, 2]
+                    for pt in ["0to40", "40to100", "100to200", "200to400", "400to600", "600toinf"]
+                ),
+                procs.n.dy_m50toinf_ge3j,
+            ],
+        },
+    }
+
     # dataset groups for conveniently looping over certain datasets
     # (used in wrapper_factory and during plotting)
     cfg.x.dataset_groups = {}
diff --git a/hbt/production/default.py b/hbt/production/default.py
index dbda5bd..3270457 100644
--- a/hbt/production/default.py
+++ b/hbt/production/default.py
@@ -5,14 +5,16 @@
 """
 
 from columnflow.production import Producer, producer
-from columnflow.production.normalization import normalization_weights
+from columnflow.production.normalization import stitched_normalization_weights
 from columnflow.production.categories import category_ids
 from columnflow.production.cms.electron import electron_weights
 from columnflow.production.cms.muon import muon_weights
 from columnflow.util import maybe_import
 
 from hbt.production.features import features
-from hbt.production.weights import normalized_pu_weight, normalized_pdf_weight, normalized_murmuf_weight
+from hbt.production.weights import (
+    normalized_pu_weight, normalized_pdf_weight, normalized_murmuf_weight,
+)
 from hbt.production.btag import normalized_btag_weights
 from hbt.production.tau import tau_weights, trigger_weights
 from hbt.util import IF_DATASET_HAS_LHE_WEIGHTS
@@ -23,12 +25,12 @@
 
 @producer(
     uses={
-        category_ids, features, normalization_weights, normalized_pu_weight,
+        category_ids, features, stitched_normalization_weights, normalized_pu_weight,
         normalized_btag_weights, tau_weights, electron_weights, muon_weights, trigger_weights,
         IF_DATASET_HAS_LHE_WEIGHTS(normalized_pdf_weight, normalized_murmuf_weight),
     },
     produces={
-        category_ids, features, normalization_weights, normalized_pu_weight,
+        category_ids, features, stitched_normalization_weights, normalized_pu_weight,
         normalized_btag_weights, tau_weights, electron_weights, muon_weights, trigger_weights,
         IF_DATASET_HAS_LHE_WEIGHTS(normalized_pdf_weight, normalized_murmuf_weight),
     },
@@ -43,7 +45,7 @@ def default(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
     # mc-only weights
     if self.dataset_inst.is_mc:
         # normalization weights
-        events = self[normalization_weights](events, **kwargs)
+        events = self[stitched_normalization_weights](events, **kwargs)
 
         # normalized pdf weight
         if self.has_dep(normalized_pdf_weight):
diff --git a/hbt/production/processes.py b/hbt/production/processes.py
new file mode 100644
index 0000000..17e3901
--- /dev/null
+++ b/hbt/production/processes.py
@@ -0,0 +1,142 @@
+# coding: utf-8
+
+"""
+Process ID producer relevant for the stitching of the DY samples.
+"""
+
+import functools
+
+import law
+
+from columnflow.production import Producer, producer
+from columnflow.util import maybe_import, InsertableDict
+from columnflow.columnar_util import set_ak_column
+
+from hbt.util import IF_DATASET_IS_DY
+
+np = maybe_import("numpy")
+ak = maybe_import("awkward")
+sp = maybe_import("scipy")
+maybe_import("scipy.sparse")
+
+
+logger = law.logger.get_logger(__name__)
+
+NJetsRange = tuple[int, int]
+PtRange = tuple[float, float]
+
+set_ak_column_i64 = functools.partial(set_ak_column, value_type=np.int64)
+
+
+@producer(
+    uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt")},
+    produces={IF_DATASET_IS_DY("process_id")},
+)
+def process_ids_dy(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
+    """
+    Assigns each dy event a single process id, based on the number of jets and the di-lepton pt of
+    the LHE record. This is used for the stitching of the DY samples.
+    """
+    # as always, we assume that each dataset has exactly one process associated with it
+    if len(self.dataset_inst.processes) != 1:
+        raise NotImplementedError(
+            f"dataset {self.dataset_inst.name} has {len(self.dataset_inst.processes)} processes "
+            "assigned, which is not yet implemented",
+        )
+    process_inst = self.dataset_inst.processes.get_first()
+
+    # get the number of nlo jets and the di-lepton pt
+    njets = events.LHE.NpNLO
+    pt = events.LHE.Vpt
+
+    # raise a warning if a dataset was already created for a specific "bin" (leaf process),
+    # but actually does not fit
+    njets_range = process_inst.x("njets", None)
+    if njets_range is not None:
+        outliers = (njets < njets_range[0]) | (njets >= njets_range[1])
+        if ak.any(outliers):
+            logger.warning(
+                f"dataset {self.dataset_inst.name} is meant to contain njet values in the range "
+                f"[{njets_range[0]}, {njets_range[1]}), but found {ak.sum(outliers)} events "
+                "outside this range",
+            )
+    pt_range = process_inst.x("ptll", None)
+    if pt_range is not None:
+        outliers = (pt < pt_range[0]) | (pt >= pt_range[1])
+        if ak.any(outliers):
+            logger.warning(
+                f"dataset {self.dataset_inst.name} is meant to contain ptll values in the range "
+                f"[{pt_range[0]}, {pt_range[1]}), but found {ak.sum(outliers)} events outside this "
+                "range",
+            )
+
+    # lookup the id and check for invalid values
+    process_ids = np.squeeze(np.asarray(self.id_table[self.key_func(njets, pt)].todense()))
+    invalid_mask = process_ids == 0
+    if ak.any(invalid_mask):
+        raise ValueError(
+            f"found {sum(invalid_mask)} dy events that could not be assigned to a process",
+        )
+
+    # store them
+    events = set_ak_column_i64(events, "process_id", process_ids)
+
+    return events
+
+
+@process_ids_dy.setup
+def process_ids_dy_setup(
+    self: Producer,
+    reqs: dict,
+    inputs: dict,
+    reader_targets: InsertableDict,
+) -> None:
+    # define stitching ranges for the DY datasets covered by this producer's dy_inclusive_dataset
+    stitching_ranges: dict[NJetsRange, list[PtRange]] = {}
+    for proc in self.dy_leaf_processes:
+        njets = proc.x.njets
+        stitching_ranges.setdefault(njets, [])
+        if proc.has_aux("ptll"):
+            stitching_ranges[njets].append(proc.x.ptll)
+
+    # sort by the first element of the ptll range
+    sorted_stitching_ranges: list[tuple[NJetsRange, list[PtRange]]] = [
+        (nj_range, sorted(stitching_ranges[nj_range], key=lambda ptll_range: ptll_range[0]))
+        for nj_range in sorted(stitching_ranges.keys(), key=lambda nj_range: nj_range[0])
+    ]
+
+    # define a key function that maps njets and pt to a unique key for use in a lookup table
+    def key_func(njets, pt):
+        # potentially convert single values into arrays
+        single = False
+        if isinstance(njets, int):
+            assert isinstance(pt, (int, float))
+            njets = np.array([njets], dtype=np.int32)
+            pt = np.array([pt], dtype=np.float32)
+            single = True
+
+        # map into bins (index 0 means no binning)
+        nj_bins = np.zeros(len(njets), dtype=np.int32)
+        pt_bins = np.zeros(len(pt), dtype=np.int32)
+        for nj_bin, (nj_range, pt_ranges) in enumerate(sorted_stitching_ranges, 1):
+            # nj_bin
+            nj_mask = (nj_range[0] <= njets) & (njets < nj_range[1])
+            nj_bins[nj_mask] = nj_bin
+            # pt_bin
+            for pt_bin, (pt_min, pt_max) in enumerate(pt_ranges, 1):
+                pt_mask = (pt_min <= pt) & (pt < pt_max)
+                pt_bins[nj_mask & pt_mask] = pt_bin
+
+        return (nj_bins[0], pt_bins[0]) if single else (nj_bins, pt_bins)
+
+    self.key_func = key_func
+
+    # define the lookup table
+    max_nj_bin = len(sorted_stitching_ranges)
+    max_pt_bin = max(map(len, stitching_ranges.values()))
+    self.id_table = sp.sparse.lil_matrix((max_nj_bin + 1, max_pt_bin + 1), dtype=np.int64)
+
+    # fill it
+    for proc in self.dy_leaf_processes:
+        key = key_func(proc.x.njets[0], proc.x("ptll", [-1])[0])
+        self.id_table[key] = proc.id
diff --git a/hbt/selection/default.py b/hbt/selection/default.py
index cfaa7be..1076449 100644
--- a/hbt/selection/default.py
+++ b/hbt/selection/default.py
@@ -4,6 +4,8 @@
 Selection methods.
 """
 
+from __future__ import annotations
+
 from operator import and_
 from functools import reduce
 from collections import defaultdict
@@ -25,6 +27,7 @@
 from hbt.selection.lepton import lepton_selection
 from hbt.selection.jet import jet_selection
 from hbt.production.features import cutflow_features
+from hbt.production.processes import process_ids_dy
 from hbt.util import IF_DATASET_HAS_LHE_WEIGHTS
 
 np = maybe_import("numpy")
@@ -114,7 +117,10 @@ def default(
     )
 
     # create process ids
-    events = self[process_ids](events, **kwargs)
+    if self.process_ids_dy is not None:
+        events = self[self.process_ids_dy](events, **kwargs)
+    else:
+        events = self[process_ids](events, **kwargs)
 
     # some cutflow features
     events = self[cutflow_features](events, results.objects, **kwargs)
@@ -177,3 +183,28 @@ def default(
     )
 
     return events, results
+
+
+@default.init
+def default_init(self: Selector) -> None:
+    if getattr(self, "dataset_inst", None) is None:
+        return
+
+    self.process_ids_dy: process_ids_dy | None = None
+    if self.dataset_inst.has_tag("is_dy"):
+        # check if this dataset is covered by any dy id producer
+        for name, dy_cfg in self.config_inst.x.dy_stitching.items():
+            dataset_inst = dy_cfg["inclusive_dataset"]
+            # the dataset is "covered" if its process is a subprocess of that of the dy dataset
+            if dataset_inst.has_process(self.dataset_inst.processes.get_first()):
+                self.process_ids_dy = process_ids_dy.derive(f"process_ids_dy_{name}", cls_dict={
+                    "dy_inclusive_dataset": dataset_inst,
+                    "dy_leaf_processes": dy_cfg["leaf_processes"],
+                })
+
+                # add it as a dependency
+                self.uses.add(self.process_ids_dy)
+                self.produces.add(self.process_ids_dy)
+
+                # stop after the first match
+                break
diff --git a/hbt/util.py b/hbt/util.py
index be57351..fa4a710 100644
--- a/hbt/util.py
+++ b/hbt/util.py
@@ -46,3 +46,14 @@ def IF_DATASET_HAS_LHE_WEIGHTS(
         return self.get()
 
     return None if func.dataset_inst.has_tag("no_lhe_weights") else self.get()
+
+
+@deferred_column
+def IF_DATASET_IS_DY(
+    self: ArrayFunction.DeferredColumn,
+    func: ArrayFunction,
+) -> Any | set[Any]:
+    if getattr(func, "dataset_inst", None) is None:
+        return self.get()
+
+    return self.get() if func.dataset_inst.has_tag("is_dy") else None
diff --git a/modules/columnflow b/modules/columnflow
index d30bde5..d4235dc 160000
--- a/modules/columnflow
+++ b/modules/columnflow
@@ -1 +1 @@
-Subproject commit d30bde57e59a22161366e9514bd02d8444a115b7
+Subproject commit d4235dca4f6e4ff7d2c0d319c62fc7570a5b43d7
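The heart of the patch is the binned njets/ptll lookup that process_ids_dy_setup builds from the
leaf processes' aux data. The following is a minimal standalone sketch of that technique, not part
of the patch itself: the process ids and njets/ptll ranges in leaf_processes are hypothetical
stand-ins for the order.Process instances the real producer receives, and only numpy and scipy are
used so the snippet runs on its own.

    import numpy as np
    import scipy.sparse as sp

    # hypothetical leaf processes: (process id, njets range, optional ptll range)
    leaf_processes = [
        (51, (0, 1), None),             # 0 jets
        (52, (1, 2), (0.0, 100.0)),     # 1 jet, low pt
        (53, (1, 2), (100.0, np.inf)),  # 1 jet, high pt
        (54, (2, 100), None),           # >=2 jets, no pt binning
    ]

    # collect the pt ranges per njets range and sort them, as in the producer setup
    stitching_ranges = {}
    for _, nj_range, pt_range in leaf_processes:
        stitching_ranges.setdefault(nj_range, [])
        if pt_range is not None:
            stitching_ranges[nj_range].append(pt_range)
    sorted_stitching_ranges = [
        (nj_range, sorted(stitching_ranges[nj_range]))
        for nj_range in sorted(stitching_ranges)
    ]

    def key_func(njets, pt):
        # map per-event njets/pt values to bin indices; 0 means "no matching range"
        nj_bins = np.zeros(len(njets), dtype=np.int32)
        pt_bins = np.zeros(len(pt), dtype=np.int32)
        for nj_bin, (nj_range, pt_ranges) in enumerate(sorted_stitching_ranges, 1):
            nj_mask = (nj_range[0] <= njets) & (njets < nj_range[1])
            nj_bins[nj_mask] = nj_bin
            for pt_bin, (pt_min, pt_max) in enumerate(pt_ranges, 1):
                pt_bins[nj_mask & (pt_min <= pt) & (pt < pt_max)] = pt_bin
        return nj_bins, pt_bins

    # sparse id table: one row per njets bin, one column per pt bin, 0 = unassigned
    max_pt_bin = max(map(len, stitching_ranges.values()))
    id_table = sp.lil_matrix((len(sorted_stitching_ranges) + 1, max_pt_bin + 1), dtype=np.int64)
    for proc_id, nj_range, pt_range in leaf_processes:
        nj = np.array([nj_range[0]])
        pt = np.array([pt_range[0] if pt_range else -1.0])
        id_table[key_func(nj, pt)] = proc_id

    # vectorized lookup for a batch of events
    njets = np.array([0, 1, 1, 3])
    pt = np.array([-1.0, 50.0, 250.0, -1.0])
    ids = np.squeeze(np.asarray(id_table[key_func(njets, pt)].todense()))
    print(ids)  # [51 52 53 54]

The sparse matrix keeps the table cheap even when most (njets bin, pt bin) cells stay empty, and a
zero entry after lookup directly flags events not covered by any leaf process, which is exactly the
invalid_mask check in process_ids_dy.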