From 823475950513f38b2908a04f5795982ae821cc38 Mon Sep 17 00:00:00 2001
From: Marcel Rieger
Date: Thu, 11 Jul 2024 14:34:18 +0200
Subject: [PATCH] Continue generalizing dy process ids. (#4)

* Continue generalizing dy process ids.

* Simplification.

* Typo.

* Typo.
---
 hbt/config/configs_run3.py  |   7 +-
 hbt/production/processes.py | 123 ++++++++++++++++++++++++++----------
 hbt/selection/default.py    |  36 +++++++++--
 modules/cmsdb               |   2 +-
 4 files changed, 124 insertions(+), 44 deletions(-)

diff --git a/hbt/config/configs_run3.py b/hbt/config/configs_run3.py
index 28367a5..f0a9865 100644
--- a/hbt/config/configs_run3.py
+++ b/hbt/config/configs_run3.py
@@ -276,10 +276,9 @@ def if_era(
         "sm_data": ["data"] + sm_group,
     }
 
-    # define available leaf dy processes (key=njet, value=list[tuple(min_pt, max_pt)])
-    cfg.x.dy_pt_stitching_ranges = {
-        njet: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))]
-        for njet in [1, 2]
+    # define inclusive datasets for the dy process identification
+    cfg.x.dy_inclusive_datasets = {
+        "m50toinf": cfg.datasets.n.dy_m50toinf_amcatnlo,
     }
 
     # dataset groups for conveniently looping over certain datasets
diff --git a/hbt/production/processes.py b/hbt/production/processes.py
index f6ccadf..d000abb 100644
--- a/hbt/production/processes.py
+++ b/hbt/production/processes.py
@@ -4,63 +4,118 @@
 Process ID producer relevant for the stitching of the DY samples.
 """
 
-import re
+import functools
+
+import law
 
 from columnflow.production import Producer, producer
-from columnflow.util import maybe_import
+from columnflow.util import maybe_import, InsertableDict
 from columnflow.columnar_util import set_ak_column
 
 from hbt.util import IF_DATASET_IS_DY
 
 np = maybe_import("numpy")
 ak = maybe_import("awkward")
+sp = maybe_import("scipy")
+maybe_import("scipy.sparse")
+
+
+logger = law.logger.get_logger(__name__)
+
+# helper
+set_ak_column_i64 = functools.partial(set_ak_column, value_type=np.int64)
 
 
 @producer(
-    uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt", "LHE.*")},
-    produces={"process_ids"},
+    uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt")},
+    produces={IF_DATASET_IS_DY("process_ids")},
 )
-def dy_process_ids(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
+def process_ids_dy(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
     """
-    Assigns each dy event a single process id, based on the number of jets and the di-lepton pt of the LHE record.
-    This is used for the stitching of the DY samples.
+    Assigns each dy event a single process id, based on the number of jets and the di-lepton pt of
+    the LHE record. This is used for the stitching of the DY samples.
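+    The assignment is based on a sparse lookup table, built once in the producer's setup function,
+    that maps the LHE jet multiplicity and di-lepton pt of an event to the id of a leaf process.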
""" - # trivial case + # as always, we assume that each dataset has exactly one process associated to it if len(self.dataset_inst.processes) != 1: raise NotImplementedError( f"dataset {self.dataset_inst.name} has {len(self.dataset_inst.processes)} processes " "assigned, which is not yet implemented", ) - process = self.dataset_inst.processes.get_first() - jet_match = re.match(r"^.*(\d)j.*$", process.name) - if process.is_leaf_process: - process_id = process.id - # store the column - events = set_ak_column(events, "process_id", len(events) * [process_id], value_type=np.int32) - return events - - # get the LHE Njets and Vpt + # get the number of nlo jets and the di-lepton pt njets = events.LHE.NpNLO pt = events.LHE.Vpt - process_ids = np.zeros(len(events), dtype=np.int32) - if jet_match: - n = int(jet_match.groups()[0]) - for min_pt, max_pt in self.config_inst.x.dy_pt_stitching_ranges[n]: - process_ids[(pt >= min_pt) & (pt < max_pt)] = ( - self.config_inst.get_process(f"dy_m50toinf_{n}j_pt{min_pt}to{max_pt}").id + # raise a warning if a datasets was already created for a specific "bin" (leaf process), + # but actually does not fit + njets_nominal = self.dataset_inst.x("njets") + if njets_nominal is not None and ak.any(njets != njets_nominal): + logger.warning( + f"dataset {self.dataset_inst.name} is meant to contain only {njets_nominal} jets, " + f"but the LHE record contains different jet multiplicities: {set(njets)}", + ) + pt_range_nominal = self.dataset_inst.x("ptll") + if pt_range_nominal is not None: + outliers = (pt < pt_range_nominal[0]) | (pt >= pt_range_nominal[1]) + if ak.any(outliers): + logger.warning( + f"dataset {self.dataset_inst.name} is meant to contain ptZ values in the range " + f"{pt_range_nominal[0]} to {pt_range_nominal[1]}, but found {ak.sum(outliers)} " + "events outside this range", ) - else: - process_ids[njets == 0] = self.config_inst.get_process("dy_m50toinf_0j").id - process_ids[njets >= 3] = self.config_inst.get_process("dy_m50toinf_ge3j").id - for n, pt_ranges in self.config_inst.x.dy_pt_stitching_ranges.items(): - for min_pt, max_pt in pt_ranges: - process_ids[(njets == n) & (pt >= min_pt) & (pt < max_pt)] = ( - self.config_inst.get_process(f"dy_m50toinf_{n}j_pt{min_pt}to{max_pt}").id - ) - - # store the column - events = set_ak_column(events, "process_id", len(events) * [process_id], value_type=np.int32) + # get the LHE Njets and Vpt to assign each event into a leaf process using the lookup table + process_ids = np.array(self.id_table[0, self.key_func(njets, pt)].todense())[0] + events = set_ak_column_i64(events, "process_id", process_ids) + return events + + +@process_ids_dy.setup +def process_ids_dy_setup( + self: Producer, + reqs: dict, + inputs: dict, + reader_targets: InsertableDict, +) -> None: + # define stitching ranges for the DY datasets covered by this producer's dy_inclusive_dataset + # TODO: extract the following from the datasets' aux info + stitching_ranges = { + 0: None, + 1: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))], + 2: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))], + } + + # define a key function that maps njets and pt to a unique key for use in a lookup table + def key_func(njets, pt): + # potentially convert single values into arrays + single = False + if isinstance(njets, int) and isinstance(pt, (int, float)): + njets = np.array([njets], dtype=np.int32) + pt = np.array([pt], dtype=np.float32) + single = True + + # map into pt bins (index 0 means no binning) + pt_bin = 
+        for nj, pt_ranges in stitching_ranges.items():
+            if not pt_ranges:
+                continue
+            nj_mask = njets == nj
+            for i, (pt_min, pt_max) in enumerate(pt_ranges, 1):
+                pt_mask = (pt_min <= pt) & (pt < pt_max)
+                pt_bin[nj_mask & pt_mask] = i
+            # pt bins of events with more jets than the largest configured njet value are reset to 0
+            pt_bin[njets >= (nj + 1)] = 0
+
+        # compute the key
+        key = njets * 100 + pt_bin
+
+        return key[0] if single else key
+
+    # save the key function
+    self.key_func = key_func
+
+    # define the lookup table (filling it is still a TODO)
+    max_nj = max(stitching_ranges.keys()) + 1
+    self.id_table = sp.sparse.lil_matrix((1, key_func(max_nj, 0) + 1), dtype=np.int64)
+    # TODO: fill
diff --git a/hbt/selection/default.py b/hbt/selection/default.py
index e9d973b..57485aa 100644
--- a/hbt/selection/default.py
+++ b/hbt/selection/default.py
@@ -4,6 +4,8 @@
 Selection methods.
 """
 
+from __future__ import annotations
+
 from operator import and_
 from functools import reduce
 from collections import defaultdict
@@ -25,7 +27,7 @@
 from hbt.selection.lepton import lepton_selection
 from hbt.selection.jet import jet_selection
 from hbt.production.features import cutflow_features
-from hbt.production.processes import dy_process_ids
+from hbt.production.processes import process_ids_dy
 from hbt.util import IF_DATASET_HAS_LHE_WEIGHTS
 
 np = maybe_import("numpy")
@@ -36,12 +38,12 @@
     uses={
         json_filter, met_filters, trigger_selection, lepton_selection, jet_selection, mc_weight,
         pu_weight, btag_weights, process_ids, cutflow_features, increment_stats,
-        attach_coffea_behavior, dy_process_ids,
+        attach_coffea_behavior,
         IF_DATASET_HAS_LHE_WEIGHTS(pdf_weights, murmuf_weights),
     },
     produces={
         trigger_selection, lepton_selection, jet_selection, mc_weight, pu_weight, btag_weights,
-        process_ids, cutflow_features, increment_stats, dy_process_ids,
+        process_ids, cutflow_features, increment_stats,
         IF_DATASET_HAS_LHE_WEIGHTS(pdf_weights, murmuf_weights),
     },
     sandbox=dev_sandbox("bash::$HBT_BASE/sandboxes/venv_columnar_tf.sh"),
@@ -114,8 +116,8 @@ def default(
     )
 
     # create process ids
-    if self.dataset_inst.has_tag("is_dy"):
-        events = self[dy_process_ids](events, **kwargs)
+    if self.process_ids_dy is not None:
+        events = self[self.process_ids_dy](events, **kwargs)
     else:
         events = self[process_ids](events, **kwargs)
 
@@ -180,3 +182,27 @@ def default(
     )
 
     return events, results
+
+
+@default.init
+def default_init(self: Selector) -> None:
+    if getattr(self, "dataset_inst", None) is None:
+        return
+
+    self.process_ids_dy: process_ids_dy | None = None
+    if self.dataset_inst.has_tag("is_dy"):
+        # check if this dataset is covered by any dy id producer
+        for name, dataset_inst in self.config_inst.x.dy_inclusive_datasets.items():
+            # the dataset is "covered" if its process is a subprocess of that of the dy dataset
+            if dataset_inst.has_process(self.dataset_inst.processes.get_first()):
+                self.process_ids_dy = process_ids_dy.derive(
+                    f"process_ids_dy_{name}",
+                    cls_dict={"dy_inclusive_dataset": dataset_inst},
+                )
+
+                # add it as a dependency
+                self.uses.add(self.process_ids_dy)
+                self.produces.add(self.process_ids_dy)
+
+                # stop after the first match
+                break
diff --git a/modules/cmsdb b/modules/cmsdb
index ca5abce..d3977b1 160000
--- a/modules/cmsdb
+++ b/modules/cmsdb
@@ -1 +1 @@
-Subproject commit ca5abce2996e4180a6c11171ad15c7d7a21f86f3
+Subproject commit d3977b16243b2585fedd93a4aa7f359aa9443b44
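
For reference, the following self-contained Python sketch mirrors the lookup-table mechanics that
this patch adds to hbt/production/processes.py. The stitching ranges are copied from the patch;
the filled-in ids (51001, 52005) and the process names in the comments are made-up placeholders,
since the patch itself still marks the actual filling of the table as "TODO: fill".

import numpy as np
import scipy.sparse as sparse

# njet -> list of (min_pt, max_pt) bins, as configured in process_ids_dy_setup
stitching_ranges = {
    0: None,
    1: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))],
    2: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))],
}

def key_func(njets, pt):
    # same encoding as in the patch: key = njets * 100 + pt_bin, where pt_bin counts
    # from 1 per jet multiplicity and 0 means "not binned in pt"
    single = False
    if isinstance(njets, int) and isinstance(pt, (int, float)):
        njets = np.array([njets], dtype=np.int32)
        pt = np.array([pt], dtype=np.float32)
        single = True
    pt_bin = np.zeros(len(pt), dtype=np.int32)
    for nj, pt_ranges in stitching_ranges.items():
        if not pt_ranges:
            continue
        nj_mask = njets == nj
        for i, (pt_min, pt_max) in enumerate(pt_ranges, 1):
            pt_bin[nj_mask & (pt_min <= pt) & (pt < pt_max)] = i
        pt_bin[njets >= (nj + 1)] = 0
    key = njets * 100 + pt_bin
    return key[0] if single else key

# a single row with one column per possible key, int64 to hold process ids
max_nj = max(stitching_ranges.keys()) + 1
id_table = sparse.lil_matrix((1, key_func(max_nj, 0) + 1), dtype=np.int64)

# fill two example keys with placeholder ids (the real ids would come from process instances)
id_table[0, key_func(1, 20.0)] = 51001   # e.g. dy_m50toinf_1j_pt0to40
id_table[0, key_func(2, 450.0)] = 52005  # e.g. dy_m50toinf_2j_pt400to600

# vectorized per-event lookup, as done in the producer
njets = np.array([1, 2], dtype=np.int32)
pt = np.array([20.0, 450.0], dtype=np.float32)
print(np.array(id_table[0, key_func(njets, pt)].todense())[0])  # [51001 52005]

Since the keys reserve 100 slots per jet multiplicity, every (njets, pt) combination collapses to
one integer, and the per-event assignment becomes a single sparse-row lookup instead of a chain of
boolean-mask comparisons per leaf process.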