-
Notifications
You must be signed in to change notification settings - Fork 0
/
SARCDataset.py
46 lines (33 loc) · 1.01 KB
/
SARCDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
from preprocessing import preprocess
from torch.utils.data import Dataset
class SARCDataset(Dataset):
    """Map-style dataset of pre-tokenized texts paired with labels.

    Every text is passed through ``preprocess`` and tokenized once, up front,
    in ``__init__``; ``__getitem__`` afterwards only performs lookups.
    """

    def __init__(self, X, y, tokenizer, *, max_length: int = 150):
        """Preprocess and tokenize all texts eagerly.

        Args:
            X: Iterable of raw texts.
            y: Labels indexable by the same positions as ``X``, or ``None``
                for unlabeled data (items then carry the label ``-1``).
            tokenizer: Callable invoked once per text as
                ``tokenizer(text, padding="max_length", max_length=...,
                truncation=True, return_tensors="pt")``.
                Presumably a Hugging Face tokenizer -- confirm with callers.
            max_length: Value forwarded as the tokenizer's ``max_length``.
        """
        super().__init__()
        cleaned = [preprocess(text) for text in X]
        # Print a few preprocessed entries so the cleaning step can be
        # eyeballed before the (slower) tokenization pass starts.
        self._print_random_samples(cleaned)
        # NOTE(review): with return_tensors="pt" each encoding presumably
        # keeps a leading batch dimension of 1 -- confirm consumers expect it.
        self.texts = [
            tokenizer(
                text,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            for text in cleaned
        ]
        self.labels = y

    def _print_random_samples(self, texts) -> None:
        """Print up to five entries of ``texts`` chosen with a fixed seed.

        Indices are drawn with replacement, so an entry may appear twice.
        Nothing is printed for an empty ``texts``.
        """
        count = len(texts)
        if count == 0:
            # randint(0, 0, ...) would raise ValueError; nothing to show.
            return
        # A private, fixed-seed generator keeps the printed samples
        # reproducible without resetting NumPy's global RNG state.
        rng = np.random.RandomState(42)
        for i in rng.randint(0, count, 5):
            print(f"Entry {i}: {texts[i]}")
        print()

    def __len__(self) -> int:
        """Return the number of tokenized texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(tokenized_text, label)`` for position ``idx``.

        The label is ``-1`` when the dataset has no labels, i.e. ``labels``
        is ``None`` or the attribute is missing.
        """
        labels = getattr(self, "labels", None)
        label = -1 if labels is None else labels[idx]
        return self.texts[idx], label