Continue generalizing dy process ids. (uhh-cms#4)
* Continue generalizing dy process ids.

* Simplification.

* Typo.

* Typo.
riga authored Jul 11, 2024
1 parent f6a65c7 commit 8234759
Showing 4 changed files with 124 additions and 44 deletions.
7 changes: 3 additions & 4 deletions hbt/config/configs_run3.py
@@ -276,10 +276,9 @@ def if_era(
     "sm_data": ["data"] + sm_group,
 }
 
-# define available leaf dy processes (key=njet, value=list[tuple(min_pt, max_pt)])
-cfg.x.dy_pt_stitching_ranges = {
-    njet: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))]
-    for njet in [1, 2]
+# define inclusive datasets for the dy process identification
+cfg.x.dy_inclusive_datasets = {
+    "m50toinf": cfg.datasets.n.dy_m50toinf_amcatnlo,
 }
 
 # dataset groups for conveniently looping over certain datasets
123 changes: 89 additions & 34 deletions hbt/production/processes.py
@@ -4,63 +4,118 @@
 Process ID producer relevant for the stitching of the DY samples.
 """
 
-import re
+import functools
 
 import law
 
 from columnflow.production import Producer, producer
-from columnflow.util import maybe_import
+from columnflow.util import maybe_import, InsertableDict
 from columnflow.columnar_util import set_ak_column
 
 from hbt.util import IF_DATASET_IS_DY
 
 np = maybe_import("numpy")
 ak = maybe_import("awkward")
+sp = maybe_import("scipy")
+maybe_import("scipy.sparse")
+
+
+logger = law.logger.get_logger(__name__)
+
+# helper
+set_ak_column_i64 = functools.partial(set_ak_column, value_type=np.int64)
 
 
 @producer(
-    uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt", "LHE.*")},
-    produces={"process_ids"},
+    uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt")},
+    produces={IF_DATASET_IS_DY("process_ids")},
 )
-def dy_process_ids(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
+def process_ids_dy(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
     """
-    Assigns each dy event a single process id, based on the number of jets and the di-lepton pt of the LHE record.
-    This is used for the stitching of the DY samples.
+    Assigns each dy event a single process id, based on the number of jets and the di-lepton pt of
+    the LHE record. This is used for the stitching of the DY samples.
     """
-    # trivial case
+    # as always, we assume that each dataset has exactly one process associated to it
     if len(self.dataset_inst.processes) != 1:
         raise NotImplementedError(
             f"dataset {self.dataset_inst.name} has {len(self.dataset_inst.processes)} processes "
             "assigned, which is not yet implemented",
         )
-    process = self.dataset_inst.processes.get_first()
-    jet_match = re.match(r"^.*(\d)j.*$", process.name)
-
-    if process.is_leaf_process:
-        process_id = process.id
-        # store the column
-        events = set_ak_column(events, "process_id", len(events) * [process_id], value_type=np.int32)
-        return events
 
-    # get the LHE Njets and Vpt
+    # get the number of nlo jets and the di-lepton pt
     njets = events.LHE.NpNLO
     pt = events.LHE.Vpt
-    process_ids = np.zeros(len(events), dtype=np.int32)
 
-    if jet_match:
-        n = int(jet_match.groups()[0])
-        for min_pt, max_pt in self.config_inst.x.dy_pt_stitching_ranges[n]:
-            process_ids[(pt >= min_pt) & (pt < max_pt)] = (
-                self.config_inst.get_process(f"dy_m50toinf_{n}j_pt{min_pt}to{max_pt}").id
-            )
-    else:
-        process_ids[njets == 0] = self.config_inst.get_process("dy_m50toinf_0j").id
-        process_ids[njets >= 3] = self.config_inst.get_process("dy_m50toinf_ge3j").id
-        for n, pt_ranges in self.config_inst.x.dy_pt_stitching_ranges.items():
-            for min_pt, max_pt in pt_ranges:
-                process_ids[(njets == n) & (pt >= min_pt) & (pt < max_pt)] = (
-                    self.config_inst.get_process(f"dy_m50toinf_{n}j_pt{min_pt}to{max_pt}").id
-                )
-
-    # store the column
-    events = set_ak_column(events, "process_id", process_ids, value_type=np.int32)
+    # raise a warning if a dataset was already created for a specific "bin" (leaf process),
+    # but actually does not fit
+    njets_nominal = self.dataset_inst.x("njets")
+    if njets_nominal is not None and ak.any(njets != njets_nominal):
+        logger.warning(
+            f"dataset {self.dataset_inst.name} is meant to contain only {njets_nominal} jets, "
+            f"but the LHE record contains different jet multiplicities: {set(njets)}",
+        )
+    pt_range_nominal = self.dataset_inst.x("ptll")
+    if pt_range_nominal is not None:
+        outliers = (pt < pt_range_nominal[0]) | (pt >= pt_range_nominal[1])
+        if ak.any(outliers):
+            logger.warning(
+                f"dataset {self.dataset_inst.name} is meant to contain ptZ values in the range "
+                f"{pt_range_nominal[0]} to {pt_range_nominal[1]}, but found {ak.sum(outliers)} "
+                "events outside this range",
+            )
+
+    # assign each event to a leaf process using the lookup table
+    process_ids = np.array(self.id_table[0, self.key_func(njets, pt)].todense())[0]
+    events = set_ak_column_i64(events, "process_id", process_ids)
 
     return events
+
+
+@process_ids_dy.setup
+def process_ids_dy_setup(
+    self: Producer,
+    reqs: dict,
+    inputs: dict,
+    reader_targets: InsertableDict,
+) -> None:
+    # define stitching ranges for the DY datasets covered by this producer's dy_inclusive_dataset
+    # TODO: extract the following from the datasets' aux info
+    stitching_ranges = {
+        0: None,
+        1: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))],
+        2: [(0, 40), (40, 100), (100, 200), (200, 400), (400, 600), (600, float("inf"))],
+    }
+
+    # define a key function that maps njets and pt to a unique key for use in a lookup table
+    def key_func(njets, pt):
+        # potentially convert single values into arrays
+        single = False
+        if isinstance(njets, int) and isinstance(pt, (int, float)):
+            njets = np.array([njets], dtype=np.int32)
+            pt = np.array([pt], dtype=np.float32)
+            single = True
+
+        # map into pt bins (index 0 means no binning)
+        pt_bin = np.zeros(len(pt), dtype=np.int32)
+        for nj, pt_ranges in stitching_ranges.items():
+            if not pt_ranges:
+                continue
+            nj_mask = njets == nj
+            for i, (pt_min, pt_max) in enumerate(pt_ranges, 1):
+                pt_mask = (pt_min <= pt) & (pt < pt_max)
+                pt_bin[nj_mask & pt_mask] = i
+        # values larger than the largest configured njet value are set to 0
+        pt_bin[njets >= (nj + 1)] = 0
+
+        # compute the key
+        key = njets * 100 + pt_bin
+
+        return key[0] if single else key
+
+    # save it
+    self.key_func = key_func
+
+    # define the lookup table and fill it
+    max_nj = max(stitching_ranges.keys()) + 1
+    self.id_table = sp.sparse.lil_matrix((1, key_func(max_nj, 0) + 1), dtype=np.int64)
+    # TODO: fill
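
Aside: the setup above ends in `# TODO: fill`, so every lookup in `self.id_table` would currently return 0. Below is a minimal sketch of a fill step that could close the TODO at the end of `process_ids_dy_setup`, reusing `key_func` so that filling and lookup agree on the encoding. The leaf process names (`dy_m50toinf_0j`, `dy_m50toinf_ge3j`, `dy_m50toinf_{n}j_pt{min}to{max}`) are taken from the code removed in this commit; the exact spelling of the open upper edge as "inf" is an assumption.

# hypothetical continuation of process_ids_dy_setup, replacing "# TODO: fill";
# uses self, stitching_ranges, key_func and max_nj from the setup scope
for nj, pt_ranges in stitching_ranges.items():
    if not pt_ranges:
        # no pt binning for this jet multiplicity; a pt of -1.0 maps to pt_bin 0
        proc = self.config_inst.get_process(f"dy_m50toinf_{nj}j")
        self.id_table[0, key_func(nj, -1.0)] = proc.id
        continue
    for pt_min, pt_max in pt_ranges:
        # spell the open upper edge the way the process names do (assumed "inf")
        pt_max_str = "inf" if pt_max == float("inf") else int(pt_max)
        proc_name = f"dy_m50toinf_{nj}j_pt{int(pt_min)}to{pt_max_str}"
        # all pt values inside a bin share one key, so the lower edge is a
        # valid representative
        self.id_table[0, key_func(nj, float(pt_min))] = self.config_inst.get_process(proc_name).id

# events with njets beyond the configured multiplicities collapse to pt_bin 0
# at key max_nj * 100; map them to the inclusive >=3 jet leaf process
self.id_table[0, key_func(max_nj, -1.0)] = self.config_inst.get_process("dy_m50toinf_ge3j").id

For the ranges above this stores, for example, `dy_m50toinf_1j_pt200to400` at key 104, since `key_func(1, 250.0)` puts pt = 250 into the fourth pt bin of the one-jet ranges (key = 1 * 100 + 4).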
36 changes: 31 additions & 5 deletions hbt/selection/default.py
@@ -4,6 +4,8 @@
 Selection methods.
 """
 
+from __future__ import annotations
+
 from operator import and_
 from functools import reduce
 from collections import defaultdict
@@ -25,7 +27,7 @@
 from hbt.selection.lepton import lepton_selection
 from hbt.selection.jet import jet_selection
 from hbt.production.features import cutflow_features
-from hbt.production.processes import dy_process_ids
+from hbt.production.processes import process_ids_dy
 from hbt.util import IF_DATASET_HAS_LHE_WEIGHTS
 
 np = maybe_import("numpy")
@@ -36,12 +38,12 @@
     uses={
         json_filter, met_filters, trigger_selection, lepton_selection, jet_selection, mc_weight,
         pu_weight, btag_weights, process_ids, cutflow_features, increment_stats,
-        attach_coffea_behavior, dy_process_ids,
+        attach_coffea_behavior,
         IF_DATASET_HAS_LHE_WEIGHTS(pdf_weights, murmuf_weights),
     },
     produces={
         trigger_selection, lepton_selection, jet_selection, mc_weight, pu_weight, btag_weights,
-        process_ids, cutflow_features, increment_stats, dy_process_ids,
+        process_ids, cutflow_features, increment_stats,
         IF_DATASET_HAS_LHE_WEIGHTS(pdf_weights, murmuf_weights),
     },
     sandbox=dev_sandbox("bash::$HBT_BASE/sandboxes/venv_columnar_tf.sh"),
@@ -114,8 +116,8 @@ def default(
     )
 
     # create process ids
-    if self.dataset_inst.has_tag("is_dy"):
-        events = self[dy_process_ids](events, **kwargs)
+    if self.process_ids_dy is not None:
+        events = self[self.process_ids_dy](events, **kwargs)
     else:
         events = self[process_ids](events, **kwargs)
 
@@ -180,3 +182,27 @@
     )
 
     return events, results
+
+
+@default.init
+def default_init(self: Selector) -> None:
+    if getattr(self, "dataset_inst", None) is None:
+        return
+
+    self.process_ids_dy: process_ids_dy | None = None
+    if self.dataset_inst.has_tag("is_dy"):
+        # check if this dataset is covered by any dy id producer
+        for name, dataset_inst in self.config_inst.x.dy_inclusive_datasets.items():
+            # the dataset is "covered" if its process is a subprocess of that of the dy dataset
+            if dataset_inst.has_process(self.dataset_inst.processes.get_first()):
+                self.process_ids_dy = process_ids_dy.derive(
+                    f"process_ids_dy_{name}",
+                    cls_dict={"dy_inclusive_dataset": dataset_inst},
+                )
+
+                # add it as a dependency
+                self.uses.add(self.process_ids_dy)
+                self.produces.add(self.process_ids_dy)
+
+                # stop after the first match
+                break
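
For readers unfamiliar with the `derive` pattern used in the init hook: it creates a named subclass of the producer with the entries of `cls_dict` attached as class attributes, so each inclusive DY dataset in `cfg.x.dy_inclusive_datasets` gets its own specialized id producer. A rough sketch of what the hook resolves to for the `m50toinf` entry defined in configs_run3.py above, assuming `config_inst` is the loaded config:

# hypothetical: the derived producer for the "m50toinf" entry; equivalent to
# what default_init builds when a covered dy dataset is processed
inclusive_dataset = config_inst.x.dy_inclusive_datasets["m50toinf"]
specialized = process_ids_dy.derive(
    "process_ids_dy_m50toinf",
    cls_dict={"dy_inclusive_dataset": inclusive_dataset},
)

# the derived class behaves like any other producer and carries the inclusive
# dataset as a class attribute
assert specialized.dy_inclusive_dataset is inclusive_dataset

Keeping the coverage check in the init hook means the selector's `uses`/`produces` sets only grow by the one derived producer that actually applies to the current dataset, instead of always depending on a single global dy producer.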
