-
Notifications
You must be signed in to change notification settings - Fork 2
/
augmentations.py
57 lines (45 loc) · 1.42 KB
/
augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import albumentations as albu
import albumentations.pytorch as apt
def get_training_augmentations(
    m=(0, 0, 0), s=(1, 1, 1)
):
    """Build the albumentations pipeline used for training.

    Stages, in order: random horizontal flip, random shift/scale, Gaussian
    noise, one photometric transform, one blur, one colour/contrast
    transform, then ``Normalize`` and ``ToTensorV2``.

    Args:
        m: Per-channel mean forwarded to ``albu.Normalize``. The default is
            an immutable tuple (it used to be a list) to avoid the shared
            mutable-default pitfall; any sequence is still accepted.
        s: Per-channel std forwarded to ``albu.Normalize``; same notes as
            ``m``.

    Returns:
        An ``albu.Compose``. Its ``additional_targets`` register ``'t2'`` as
        an extra image and ``'mask3d'`` as an extra mask. They are then
        transformed together with ``image``/``mask`` using the same sampled
        parameters.
    """
    # Disabled in the original list: VerticalFlip(p=0.5), Perspective(p=0.5),
    # and Sharpen(p=1) inside the blur OneOf.
    train_transform = [
        albu.HorizontalFlip(p=0.5),
        # scale_limit=0.5 -> scale factor in [0.5, 1.5]; shift up to 10% of
        # the size; rotate_limit=0 disables rotation. border_mode=0 is
        # cv2.BORDER_CONSTANT (pad with a constant value).
        albu.ShiftScaleRotate(
            scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0
        ),
        albu.GaussNoise(p=0.2),
        # Each OneOf fires with probability 0.9 and applies exactly one child.
        albu.OneOf(
            [
                albu.CLAHE(p=1),
                # Brightness-only jitter. This is the exact replacement for
                # the removed ``albu.RandomBrightness(p=1)``, whose default
                # limit was 0.2.
                albu.RandomBrightnessContrast(
                    brightness_limit=0.2, contrast_limit=0, p=1
                ),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),
        albu.OneOf(
            [
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),
        albu.OneOf(
            [
                # Contrast-only jitter. This is the exact replacement for the
                # removed ``albu.RandomContrast(p=1)``, whose default limit
                # was 0.2.
                albu.RandomBrightnessContrast(
                    brightness_limit=0, contrast_limit=0.2, p=1
                ),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
        albu.Normalize(mean=m, std=s),
        apt.ToTensorV2(),
    ]
    return albu.Compose(
        train_transform, additional_targets={'t2': 'image', 'mask3d': 'mask'}
    )
def get_validation_augmentations(
    m=(0, 0, 0), s=(1, 1, 1)
):
    """Build the deterministic albumentations pipeline used for validation.

    No random augmentation is applied. Inputs are only normalised and
    converted to a tensor.

    Args:
        m: Per-channel mean forwarded to ``albu.Normalize``. The default is
            an immutable tuple (it used to be a list) to avoid the shared
            mutable-default pitfall; any sequence is still accepted.
        s: Per-channel std forwarded to ``albu.Normalize``; same notes as
            ``m``.

    Returns:
        An ``albu.Compose``. It registers ``'t2'`` as an extra image target
        and ``'mask3d'`` as an extra mask target, matching the training
        pipeline's targets.
    """
    val_transform = [
        albu.Normalize(mean=m, std=s),
        apt.ToTensorV2(),
    ]
    return albu.Compose(
        val_transform, additional_targets={'t2': 'image', 'mask3d': 'mask'}
    )