Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Working base #389

Merged
merged 42 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
ba81aae
adding feature option to graphnet/models/detector/icecube/IceCube86 c…
Aske-Rosted Nov 17, 2022
29ec105
Skipping frame if pulsemap missing
Aske-Rosted Dec 5, 2022
08d13f5
message passing in case of missing pulsemap
Aske-Rosted Dec 5, 2022
27a5cb9
added options for using differently named MCTree
Aske-Rosted Dec 5, 2022
dee0a9b
new implementation of feature renaming
Aske-Rosted Dec 5, 2022
72566b5
differently named feature in DOMCoarsening
Aske-Rosted Dec 5, 2022
29e8ebb
train_validation_test_dataloader
Aske-Rosted Dec 14, 2022
b7b416e
Add function for no selection criteria sql db
Aske-Rosted Jan 19, 2023
8a87fa8
Allow object type by forcing float with warning
Aske-Rosted Jan 23, 2023
079d419
Add extractor for comparisons to other algorithms
Aske-Rosted Jan 23, 2023
3e57582
Fixing according to pre-commit hooks
Aske-Rosted Jan 25, 2023
c52e051
remove @abstraclass
Aske-Rosted Jan 25, 2023
42d019e
allow for choosing parallelising strategy
Aske-Rosted Jan 31, 2023
2a02ed5
Merge branch 'main' into working_base
Aske-Rosted Feb 1, 2023
662c5e1
Merge branch 'main' into working_base
Aske-Rosted Feb 3, 2023
4d9d4af
Update src/graphnet/models/detector/detector.py
Aske-Rosted Feb 15, 2023
adfd152
Update src/graphnet/models/detector/icecube.py
Aske-Rosted Feb 15, 2023
b856bac
Update src/graphnet/models/model.py
Aske-Rosted Feb 15, 2023
2d31e2e
Update src/graphnet/training/utils.py
Aske-Rosted Feb 15, 2023
038785c
Update src/graphnet/data/dataset.py
Aske-Rosted Feb 15, 2023
b8e2077
Update src/graphnet/data/extractors/i3splinempeextractor.py
Aske-Rosted Feb 15, 2023
7248ca8
Update src/graphnet/data/sqlite/sqlite_selection.py
Aske-Rosted Feb 15, 2023
cbf30c2
Update src/graphnet/models/coarsening.py
Aske-Rosted Feb 15, 2023
9149f1e
move comparison extractor to own file
Aske-Rosted Feb 15, 2023
54892bb
fixing typo and pre-commit hooks
Aske-Rosted Feb 15, 2023
2897349
Merge branch 'main' into working_base
Aske-Rosted Feb 15, 2023
5813957
remove unused import in src/graphnet/data/sqlite/sqlite_selection.py
Aske-Rosted Feb 21, 2023
5d0d377
class moved to other module
Aske-Rosted Feb 21, 2023
8f0c4e5
Update src/graphnet/models/coarsening.py inputs
Aske-Rosted Feb 21, 2023
53d1b23
use input instead
Aske-Rosted Feb 21, 2023
baf8c27
renaming src/graphnet/data/extractors/i3comparisonextractor.py
Aske-Rosted Feb 21, 2023
8d4e8a9
renaming comparison extractor in __init__.py
Aske-Rosted Feb 21, 2023
06f21a0
added null split frame check
Aske-Rosted Feb 21, 2023
c545490
comparisonextractor -> particleextractor
Aske-Rosted Feb 21, 2023
7a727e4
comparisonextractor -> particleextractor
Aske-Rosted Feb 21, 2023
7e172fe
rm exceptional handling of missing pulsemap
Aske-Rosted Feb 21, 2023
926f1e8
fix missing capitalization
Aske-Rosted Feb 21, 2023
0c8d54c
added additional values
Aske-Rosted Feb 21, 2023
e26a73c
streamlining error handling.
Aske-Rosted Feb 21, 2023
a050176
Merge branch 'main' into working_base
Aske-Rosted Feb 21, 2023
b659190
Update src/graphnet/data/extractors/i3particleextractor.py
Aske-Rosted Feb 21, 2023
695bd90
only accept I3 exceptions
Aske-Rosted Feb 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,16 @@ def _extract_data(self, fileset: FileSet) -> List[OrderedDict]:
except: # noqa: E722
continue

# Extract data from I3Frame
results = self._extractors(frame)
# Try to extract data from I3Frame else skip frame
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
try:
results = self._extractors(frame)
except KeyError as e:
if "Pulsemap" in str(e):
self.warning(str(e) + ". Skipping frame")
continue
else:
raise e

data_dict = OrderedDict(zip(self._table_names, results))

# If an I3GenericExtractor is used, we want each automatically
Expand Down
10 changes: 9 additions & 1 deletion src/graphnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,15 @@ def _create_graph(
data = np.array([]).reshape((0, len(self._features) - 1))

# Construct graph data object
x = torch.tensor(data, dtype=self._dtype) # pylint: disable=C0103
try:
x = torch.tensor(data, dtype=self._dtype) # pylint: disable=C0103
except TypeError as e:
x = torch.tensor(np.array(data, dtype=float), dtype=self._dtype)
self.warning(
r"The following type error was raised, forcing dtype float32. NaN's might occur in data! \n"
+ str(e)
)

Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
n_pulses = torch.tensor(len(x), dtype=torch.int32)
graph = Data(x=x, edge_index=None)
graph.n_pulses = n_pulses
Expand Down
2 changes: 1 addition & 1 deletion src/graphnet/data/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from .i3truthextractor import I3TruthExtractor
from .i3retroextractor import I3RetroExtractor
from .i3splinempeextractor import I3SplineMPEICExtractor
from .i3splinempeextractor import I3SplineMPEICExtractor, I3ComparisonExtractor
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
from .i3tumextractor import I3TUMExtractor
from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor
from .i3genericextractor import I3GenericExtractor
Expand Down
8 changes: 4 additions & 4 deletions src/graphnet/data/extractors/i3featureextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, List[Any]]:
}

# Get OM data
if self._pulsemap in frame:
try:
om_keys, data = get_om_keys_and_pulseseries(
frame,
self._pulsemap,
self._calibration,
)
else:
warn_once(self, f"Pulsemap {self._pulsemap} not found in frame.")
return output
except KeyError:
if self._pulsemap is not None:
raise KeyError(f"Pulsemap {self._pulsemap} not in frame")

# Added these :
bright_doms = None
Expand Down
24 changes: 23 additions & 1 deletion src/graphnet/data/extractors/i3splinempeextractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""I3Extractor class(es) for extracting SplineMPE reconstruction."""

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved

from graphnet.data.extractors.i3extractor import I3Extractor

Expand Down Expand Up @@ -28,3 +28,25 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
)

return output


class I3ComparisonExtractor(I3Extractor):
    """Class for extracting pointing predictions from various algorithms."""

    def __init__(self, name: str):
        """Construct I3ComparisonExtractor.

        Args:
            name: Frame key of the reconstruction whose pointing direction
                should be extracted (also used as the extractor's name).
        """
        # Base class constructor
        super().__init__(name)

    def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
        """Extract zenith/azimuth pointing predictions from `frame`.

        Returns an empty dict when the configured key is absent, so frames
        lacking this reconstruction are skipped silently by the caller.
        """
        output: Dict[str, float] = {}
        if self._name in frame:
            # Look the particle up once instead of twice per attribute.
            particle = frame[self._name]
            output["zenith_" + self._name] = particle.dir.zenith
            output["azimuth_" + self._name] = particle.dir.azimuth

        return output
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
19 changes: 12 additions & 7 deletions src/graphnet/data/extractors/i3truthextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class I3TruthExtractor(I3Extractor):
"""Class for extracting truth-level information."""

def __init__(
self, name: str = "truth", borders: Optional[List[np.ndarray]] = None
self,
name: str = "truth",
borders: Optional[List[np.ndarray]] = None,
mctree: Optional[str] = "I3MCTree",
):
"""Construct I3TruthExtractor.

Expand All @@ -33,6 +36,7 @@ def __init__(
coordinates, for identifying, e.g., particles starting and
stopping within the detector. Defaults to hard-coded boundary
coordinates.
mctree: Str of which MCTree to use for truth values.
"""
# Base class constructor
super().__init__(name)
Expand Down Expand Up @@ -74,13 +78,14 @@ def __init__(
self._borders = [border_xy, border_z]
else:
self._borders = borders
self._mctree = mctree

def __call__(
self, frame: "icetray.I3Frame", padding_value: Any = -1
) -> Dict[str, Any]:
"""Extract truth-level information."""
is_mc = frame_is_montecarlo(frame)
is_noise = frame_is_noise(frame)
is_mc = frame_is_montecarlo(frame, self._mctree)
is_noise = frame_is_noise(frame, self._mctree)
sim_type = self._find_data_type(is_mc, self._i3_file)

output = {
Expand Down Expand Up @@ -217,7 +222,7 @@ def __call__(
def _extract_dbang_decay_length(
self, frame: "icetray.I3Frame", padding_value: float = -1
) -> float:
mctree = frame["I3MCTree"]
mctree = frame[self._mctree]
try:
p_true = mctree.primaries[0]
p_daughters = mctree.get_daughters(p_true)
Expand Down Expand Up @@ -346,11 +351,11 @@ def _get_primary_particle_interaction_type_and_elasticity(
try:
MCInIcePrimary = frame["MCInIcePrimary"]
except KeyError:
MCInIcePrimary = frame["I3MCTree"][0]
MCInIcePrimary = frame[self._mctree][0]
if (
MCInIcePrimary.energy != MCInIcePrimary.energy
): # This is a nan check. Only happens for some muons where second item in MCTree is primary. Weird!
MCInIcePrimary = frame["I3MCTree"][
MCInIcePrimary = frame[self._mctree][
1
] # For some strange reason the second entry is identical in all variables and has no nans (always muon)
else:
Expand Down Expand Up @@ -380,7 +385,7 @@ def _get_primary_track_energy_and_inelasticity(
Tuple containing the energy of tracks from primary, and the
corresponding inelasticity.
"""
mc_tree = frame["I3MCTree"]
mc_tree = frame[self._mctree]
primary = mc_tree.primaries[0]
daughters = mc_tree.get_daughters(primary)
tracks = []
Expand Down
12 changes: 8 additions & 4 deletions src/graphnet/data/extractors/utilities/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
) # pyright: reportMissingImports=false


def frame_is_montecarlo(
    frame: "icetray.I3Frame", mctree: str = "I3MCTree"
) -> bool:
    """Check whether `frame` is from Monte Carlo simulation.

    Args:
        frame: I3-frame to inspect.
        mctree: Name of the MC-tree key to look for. Annotated as `str`
            (not `Optional[str]`): passing `None` would silently test
            `None in frame` rather than signal "no tree".

    Returns:
        True if the frame contains either an `MCInIcePrimary` key or the
        named MC tree.
    """
    return ("MCInIcePrimary" in frame) or (mctree in frame)


def frame_is_noise(frame: "icetray.I3Frame") -> bool:
def frame_is_noise(
frame: "icetray.I3Frame", mctree: Optional[str] = "I3MCTree"
) -> bool:
"""Check whether `frame` is from noise."""
try:
frame["I3MCTree"][0].energy
frame[mctree][0].energy
return False
except: # noqa: E722
try:
Expand Down
28 changes: 27 additions & 1 deletion src/graphnet/data/sqlite/sqlite_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,38 @@

import numpy as np
import pandas as pd

import random
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
from graphnet.utilities.logging import get_logger

logger = get_logger()


def get_random_fraction(
    database: str, fraction: float, seed: int = 0
) -> List[int]:
    """Get a random, unique fraction of events without selection criteria.

    Args:
        database: Path to database from which to get event numbers.
        fraction: Desired fraction of the total number of events.
        seed: Random number generator seed, for reproducible selections.

    Returns:
        List of sampled event numbers.
    """
    rng = np.random.RandomState(seed=seed)
    # NOTE: `sqlite3.connect(...)` as a context manager only manages the
    # transaction — it does not close the connection. Close it explicitly
    # to avoid leaking file handles when called repeatedly.
    con = sqlite3.connect(database)
    try:
        tmp_dataframe = pd.read_sql("SELECT event_no FROM truth", con)
    finally:
        con.close()
    event_no_list = (
        tmp_dataframe.sample(frac=fraction, replace=False, random_state=rng)
        .values.ravel()
        .tolist()
    )
    return event_no_list


Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
def get_desired_event_numbers(
database: str,
desired_size: int,
Expand Down
36 changes: 29 additions & 7 deletions src/graphnet/models/coarsening.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,25 @@ def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
class DOMCoarsening(Coarsening):
    """Coarsen pulses to DOM-level."""

    def __init__(
        self,
        reduce: str = "avg",
        transfer_attributes: bool = True,
        keys: Union[List[str], None] = None,
    ):
        """Cluster pulses on the same DOM.

        Args:
            reduce: Name of the reduction operation, forwarded to the base
                `Coarsening` class.
            transfer_attributes: Whether to transfer attributes during
                coarsening, forwarded to the base class.
            keys: Node feature names that jointly identify a DOM; pulses
                sharing all of these values are clustered together.
                Defaults to the standard IceCube DOM features.
        """
        # Base class constructor
        super().__init__(reduce, transfer_attributes)
        # Use a `None` sentinel instead of a mutable list literal as the
        # default, so the default list is not shared across instances.
        if keys is None:
            keys = ["dom_x", "dom_y", "dom_z", "rde", "pmt_area"]
        self._keys = keys

    def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
        """Cluster nodes in `data` by assigning a cluster index to each."""
        dom_index = group_by(data, self._keys)
        return dom_index


Expand Down Expand Up @@ -271,23 +285,31 @@ def __init__(
time_window: float,
reduce: str = "avg",
transfer_attributes: bool = True,
keys: List[str] = [
"dom_x",
"dom_y",
"dom_z",
"rde",
"pmt_area",
],
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
):
"""Cluster pulses on the same DOM within `time_window`."""
super().__init__(reduce, transfer_attributes)
self._time_window = time_window
self._cluster_method = DBSCAN(self._time_window, min_samples=1)
self._keys = keys
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved

def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
"""Cluster nodes in `data` by assigning a cluster index to each."""
dom_index = group_by(
data, ["dom_x", "dom_y", "dom_z", "rde", "pmt_area"]
)
dom_index = group_by(data, self._keys)
if data.batch is not None:
features = data.features[0]
else:
features = data.features

ix_time = features.index("dom_time")
ix_time = [idx for idx, s in enumerate(features) if "time" in s][
0
] # features.index("dom_time")
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
hit_times = data.x[:, ix_time]

# Scale up dom_index to make sure clusters are well separated
Expand Down
4 changes: 4 additions & 0 deletions src/graphnet/models/detector/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,7 @@ def _validate_features(self, data: Data) -> None:
assert (
data_features == self.features
), f"Features on Data and Detector differ: {data_features} vs. {self.features}"

def rename_features(self, features: List[str]) -> None:
    """Reassign the list of feature names on this Detector instance.

    Intended for datasets whose columns use different names than the
    Detector's defaults; callers are responsible for supplying names in
    the order the Detector expects.
    """
    self.features = features
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions src/graphnet/models/detector/icecube.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from graphnet.data.constants import FEATURES
from graphnet.models.detector.detector import Detector
from typing import Any, Dict, List, Optional, Tuple, Union
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved


class IceCube86(Detector):
Expand Down
1 change: 1 addition & 0 deletions src/graphnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _construct_trainers(
logger: Optional[Logger] = None,
log_every_n_steps: int = 1,
gradient_clip_val: Optional[float] = None,
strategy: Optional[str] = "ddp",
Aske-Rosted marked this conversation as resolved.
Show resolved Hide resolved
distribution_strategy: Optional[str] = "ddp",
**trainer_kwargs: Any,
) -> None:
Expand Down
Loading