-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adjusted the code structure, documented the code
- Loading branch information
1 parent
8edebf1
commit 1c0f3e1
Showing
1 changed file
with
269 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,122 +1,273 @@ | ||
import numpy as np | ||
from tmp import get_micro_metrics | ||
|
||
import numpy as np | ||
|
||
|
||
class Meter(object): | ||
def reset(self): | ||
pass | ||
def update(self,value): | ||
pass | ||
def get_update(self): | ||
pass | ||
|
||
class AverageMeter(Meter): | ||
|
||
def __init__(self): | ||
super(AverageMeter,self).__init__() | ||
self.reset() | ||
def reset(self): | ||
self.value = 0. | ||
self.average = 0. | ||
self.count = 0. | ||
def update(self,value): | ||
self.count += 1. | ||
self.value = value | ||
self.average = (self.average * (self.count - 1)+ self.value)/float(self.count) | ||
def get_update(self): | ||
return self.average | ||
|
||
class ShortTermMemoryMeter(Meter): | ||
def __init__(self,memory_length): | ||
super(ShortTermMemoryMeter,self).__init__() | ||
self.reset() | ||
self.memory_length = memory_length | ||
assert(self.memory_length >1) | ||
#to atleast average last 2 batch's results or else it would be only performing additional non-neccessary operations | ||
def reset(self): | ||
self.value = 0 | ||
self.length = 0 | ||
self.in_memory = [] | ||
self.average = 0 | ||
def update(self,value): | ||
self.value = value | ||
if (self.length >= self.memory_length): | ||
self.in_memory = self.in_memory[1:] | ||
self.in_memory.append(self.value) | ||
self.average = np.average(np.array(self.in_memory)) | ||
self.length = len(self.in_memory) | ||
def get_update(self): | ||
return self.average | ||
|
||
class Sampler(Meter): | ||
def __init__(self,value_names,rate): | ||
super(Sampler,self).__init__() | ||
self.reset() | ||
self.v_names = value_names | ||
self.rate = rate | ||
def reset(self): | ||
self.history=dict() | ||
for name in self.v_names: | ||
self.history[name] = list() | ||
def update(self,values,i): | ||
if(i%self.rate == 0): | ||
for j,k in enumerate(values): | ||
name = self.v_names[j] | ||
self.history[name].append(k) | ||
def get_update(self): | ||
return self.history | ||
|
||
class Meter(object): | ||
""" | ||
Meter is a general interface / class to use as a base for our classes | ||
With no implemented methods (to be implemented in children) | ||
""" | ||
def reset(self): | ||
""" | ||
Resets the meter to its initial state. | ||
""" | ||
|
||
def update(self,value): | ||
""" | ||
Updates the meter with a new value. | ||
Args: | ||
value: The value to update the meter with. | ||
""" | ||
|
||
def get_update(self): | ||
""" | ||
Gets the current state or value of the meter. | ||
Returns: | ||
The current state or value of the meter. | ||
""" | ||
|
||
class AverageMeter(Meter): | ||
""" | ||
Initializes an instance of the AverageMeter class subclass of Meter. | ||
""" | ||
def __init__(self): | ||
super(AverageMeter,self).__init__() | ||
self.value=0. | ||
self.average=0. | ||
self.count=0. | ||
|
||
def reset(self): | ||
""" | ||
Resets the AverageMeter to its initial state. | ||
""" | ||
self.value = 0. | ||
self.average = 0. | ||
self.count = 0. | ||
|
||
def update(self,value): | ||
""" | ||
Updates the AverageMeter with a new value. | ||
Args: | ||
value: The value to update the AverageMeter with. | ||
""" | ||
self.count += 1. | ||
self.value = value | ||
self.average = (self.average * (self.count - 1)+ self.value)/float(self.count) | ||
|
||
def get_update(self): | ||
""" | ||
Gets the current average value of the AverageMeter. | ||
Returns: | ||
The current average value of the AverageMeter. | ||
""" | ||
return self.average | ||
|
||
class ShortTermMemoryMeter(Meter): | ||
""" | ||
Initializes an instance of the ShortTermMemoryMeter class sublcass of Meter. | ||
Args: | ||
memory_length (int): Length of the short-term memory window. | ||
Should be greater than 1 for meaningful averaging. | ||
""" | ||
def __init__(self,memory_length): | ||
super(ShortTermMemoryMeter,self).__init__() | ||
self.value = 0 | ||
self.length = 0 | ||
self.in_memory = [] | ||
self.average = 0 | ||
self.memory_length = memory_length | ||
assert(self.memory_length >1) | ||
#to at least average last 2 batches results or else it would be only performing additional non-neccessary operations | ||
|
||
def reset(self): | ||
""" | ||
Resets the ShortTermMemoryMeter to its initial state. | ||
""" | ||
self.value = 0 | ||
self.length = 0 | ||
self.in_memory = [] | ||
self.average = 0 | ||
|
||
def update(self,value): | ||
""" | ||
Updates the ShortTermMemoryMeter with a new value. | ||
Args: | ||
value: The value to update the ShortTermMemoryMeter with. | ||
""" | ||
self.value = value | ||
if (self.length >= self.memory_length): | ||
self.in_memory = self.in_memory[1:] | ||
self.in_memory.append(self.value) | ||
self.average = np.average(np.array(self.in_memory)) | ||
self.length = len(self.in_memory) | ||
|
||
def get_update(self): | ||
""" | ||
Gets the current average value of the ShortTermMemoryMeter. | ||
Returns: | ||
The current average value of the ShortTermMemoryMeter. | ||
""" | ||
return self.average | ||
|
||
class Sampler(Meter): | ||
""" | ||
Initializes an instance of the Sampler class. | ||
Args: | ||
value_names (list): List of names for the sampled values. | ||
rate (int): Sampling rate, i.e., how often to store values in the history. | ||
""" | ||
def __init__(self,value_names,rate): | ||
super(Sampler,self).__init__() | ||
self.v_names = value_names | ||
self.rate = rate | ||
self.reset() | ||
|
||
classes = ['overall','background' , 'building_flooded' ,'building_non-flooded' , 'road_flooded' , 'road_non-flooded' , 'water' , 'tree' , 'vehicle' , 'pool' , 'grass'] | ||
def register_metrics_avg_meters(names = classes,metric_names = ['iou'],thresh = 0.5): | ||
l = len(names) | ||
metrics = get_micro_metrics( | ||
metrics=[metric_names] * l, | ||
threshs=[thresh] * l, | ||
channels=[None,*[i for i in range(l-1)]], | ||
names=classes, | ||
num_cls=l | ||
) | ||
tups = [] | ||
for metric in metrics: | ||
tups.append([metric,AverageMeter()]) | ||
return tups | ||
|
||
def get_meters_info(seg_loss_meter,seg_loss_name,cls_loss_meter,cls_loss_name, MnMs,cls_accuracy_meter,cls_f1score_meter): | ||
info = '' | ||
info += f'| {seg_loss_name} : {seg_loss_meter.get_update():.5} ' | ||
info += f'| {cls_loss_name} : {cls_loss_meter.get_update():.5} ' | ||
info += f'| cls_accuracy : {cls_accuracy_meter.get_update():.5} ' | ||
info += f'| cls_f1score : {cls_f1score_meter.get_update():.5} ' | ||
for metric,meter in MnMs: | ||
info += f'| {metric.__name__} : {meter.get_update():.5}' | ||
info += ' |' | ||
return info | ||
|
||
def register_scores(pdict,seg_loss_meter,seg_criterion,cls_loss_meter,cls_criterion, MnMs,cls_accuracy_meter,cls_f1score_meter,prefix = 'train'): | ||
assert prefix in ['train','val'],'{} is not a valid prefix'.format(prefix) | ||
seg_loss_name = f'{prefix}_seg_' + seg_criterion.__name__ | ||
cls_loss_name = f'{prefix}_cls_' + cls_criterion.__name__ | ||
keys = list(pdict.keys()) | ||
|
||
if(seg_loss_name not in keys): | ||
pdict[seg_loss_name] = [] | ||
pdict[cls_loss_name] = [] | ||
pdict['cls_accuracy'] = [] | ||
pdict['cls_f1score'] = [] | ||
|
||
pdict[seg_loss_name].append(seg_loss_meter.get_update()) | ||
pdict[cls_loss_name].append(cls_loss_meter.get_update()) | ||
pdict['cls_accuracy'].append(cls_accuracy_meter.get_update()) | ||
pdict['cls_f1score'].append(cls_f1score_meter.get_update()) | ||
|
||
for metric,meter in MnMs: | ||
if(f'{prefix}_' + metric.__name__ not in keys): | ||
pdict[f'{prefix}_' + metric.__name__] = [] | ||
pdict[f'{prefix}_' + metric.__name__].append(meter.get_update()) | ||
return pdict | ||
|
||
|
||
def reset(self): | ||
""" | ||
Resets the Sampler to its initial state. | ||
""" | ||
self.history=dict() | ||
for name in self.v_names: | ||
self.history[name] = list() | ||
|
||
|
||
def update(self,values,i): | ||
""" | ||
Updates the Sampler with a new set of values at a given iteration. | ||
Args: | ||
values (list): List of values to update the Sampler with. | ||
i (int): Current iteration or step. | ||
Note: | ||
- Values are stored in the history based on the specified sampling rate. | ||
""" | ||
if(i%self.rate == 0): | ||
for j,k in enumerate(values): | ||
name = self.v_names[j] | ||
self.history[name].append(k) | ||
|
||
def get_update(self): | ||
""" | ||
Gets the current history of sampled values. | ||
Returns: | ||
dict: Dictionary containing sampled values for each value name. | ||
""" | ||
return self.history | ||
|
||
class MetricsManager(): | ||
""" | ||
A class for managing and reporting various metrics during training or validation. | ||
""" | ||
classes = ['overall','background' , 'building_flooded' ,'building_non-flooded' , 'road_flooded' , | ||
'road_non-flooded' , 'water' , 'tree' , 'vehicle' , 'pool' , 'grass'] | ||
|
||
def __init__(self, classes, metric_names=['iou'], thresh=0.5): | ||
""" | ||
Initializes an instance of the MetricsManager class. | ||
Args: | ||
classes (list): List of class names or labels. | ||
metric_names (list, optional): List of metric names. Default is ['iou']. | ||
thresh (float, optional): Threshold value for metrics. Default is 0.5. | ||
""" | ||
self.classes = classes | ||
self.metric_names = metric_names | ||
self.thresh = thresh | ||
self.metrics = self.register_metrics_avg_meters() | ||
|
||
def register_metrics_avg_meters(self): | ||
""" | ||
Registers metric functions with AverageMeter instances. | ||
Returns: | ||
list: List of tuples containing metric functions and corresponding AverageMeter instances. | ||
""" | ||
l = len(self.classes) | ||
#get_micro_metrics is from the NoteBook which includes a lot of dependecies -> new file with all functions ? | ||
metrics = get_micro_metrics( | ||
metrics=[self.metric_names] * l, | ||
threshs=[self.thresh] * l, | ||
channels=[None,*[i for i in range(l-1)]], | ||
names=self.classes, | ||
num_cls=l | ||
) | ||
tups = [] | ||
for metric in metrics: | ||
tups.append([metric,AverageMeter()]) | ||
return tups | ||
|
||
@staticmethod | ||
def get_meters_info(seg_loss_meter,seg_loss_name,cls_loss_meter,cls_loss_name, mnms,cls_accuracy_meter,cls_f1score_meter): | ||
""" | ||
Gets a formatted string with information about various metrics. | ||
Args: | ||
seg_loss_meter (Meter): Meter instance for segmentation loss. | ||
seg_loss_name (str): Name for segmentation loss. | ||
cls_loss_meter (Meter): Meter instance for classification loss. | ||
cls_loss_name (str): Name for classification loss. | ||
mnms (list): List of tuples containing metric functions and corresponding Meter instances. | ||
cls_accuracy_meter (Meter): Meter instance for classification accuracy. | ||
cls_f1score_meter (Meter): Meter instance for classification F1 score. | ||
Returns: | ||
str: Formatted string with information about various metrics. | ||
""" | ||
info = '' | ||
info += f'| {seg_loss_name} : {seg_loss_meter.get_update():.5} ' | ||
info += f'| {cls_loss_name} : {cls_loss_meter.get_update():.5} ' | ||
info += f'| cls_accuracy : {cls_accuracy_meter.get_update():.5} ' | ||
info += f'| cls_f1score : {cls_f1score_meter.get_update():.5} ' | ||
for metric,meter in mnms: | ||
info += f'| {metric.__name__} : {meter.get_update():.5}' | ||
info += ' |' | ||
return info | ||
|
||
@staticmethod | ||
def register_scores(pdict,seg_loss_meter,seg_criterion,cls_loss_meter,cls_criterion, mnms,cls_accuracy_meter,cls_f1score_meter,prefix = 'train'): | ||
""" | ||
Registers scores for different metrics in a provided dictionary. | ||
Args: | ||
pdict (dict): Dictionary to store metric scores. | ||
seg_loss_meter (Meter): Meter instance for segmentation loss. | ||
seg_criterion (torch.nn.Module): Segmentation loss criterion. | ||
cls_loss_meter (Meter): Meter instance for classification loss. | ||
cls_criterion (torch.nn.Module): Classification loss criterion. | ||
mnms (list): List of tuples containing custom metrics and their corresponding Meter instances. | ||
cls_accuracy_meter (Meter): Meter instance for classification accuracy. | ||
cls_f1score_meter (Meter): Meter instance for classification F1 score. | ||
prefix (str, optional): Prefix for metric names (e.g., 'train' or 'val'). Default is 'train'. | ||
Returns: | ||
dict: Updated dictionary with metric scores. | ||
""" | ||
assert prefix in ['train','val'],'{} is not a valid prefix'.format(prefix) | ||
seg_loss_name = f'{prefix}_seg_' + seg_criterion.__name__ | ||
cls_loss_name = f'{prefix}_cls_' + cls_criterion.__name__ | ||
keys = list(pdict.keys()) | ||
|
||
if(seg_loss_name not in keys): | ||
pdict[seg_loss_name] = [] | ||
pdict[cls_loss_name] = [] | ||
pdict['cls_accuracy'] = [] | ||
pdict['cls_f1score'] = [] | ||
|
||
pdict[seg_loss_name].append(seg_loss_meter.get_update()) | ||
pdict[cls_loss_name].append(cls_loss_meter.get_update()) | ||
pdict['cls_accuracy'].append(cls_accuracy_meter.get_update()) | ||
pdict['cls_f1score'].append(cls_f1score_meter.get_update()) | ||
|
||
for metric,meter in mnms: | ||
if(f'{prefix}_' + metric.__name__ not in keys): | ||
pdict[f'{prefix}_' + metric.__name__] = [] | ||
pdict[f'{prefix}_' + metric.__name__].append(meter.get_update()) | ||
return pdict |