diff --git a/src/explib/datasets.py b/src/explib/datasets.py index 8a0941b..cdfc037 100644 --- a/src/explib/datasets.py +++ b/src/explib/datasets.py @@ -21,8 +21,9 @@ class DequantizedDataset(torch.utils.data.Dataset): def __init__( self, dataset: T.Union[os.PathLike, torch.utils.data.Dataset, np.ndarray], + labels: T.Union[np.ndarray, torch.Tensor], num_bits: int = 8, - device: torch.device = None, + device: torch.device = "cpu", ): if isinstance(dataset, torch.utils.data.Dataset) or isinstance( dataset, np.ndarray @@ -31,8 +32,15 @@ def __init__( else: self.dataset = pd.read_csv(dataset).values - # + if not isinstance(self.dataset, torch.Tensor): + self.dataset = torch.tensor(self.dataset) + self.dataset = self.dataset.to(device) + + if not isinstance(labels, torch.Tensor): + labels = torch.Tensor(labels) + self.labels = labels.to(device) + self.num_bits = num_bits self.num_levels = 2**num_bits self.transform = transforms.Compose( @@ -43,9 +51,9 @@ def __init__( ) def __getitem__(self, index: int): - x, y = self.dataset[index] + x = self.dataset[index] x = Tensor(self.transform(x)) - return x, y + return x, self.labels[index] def __len__(self): return len(self.dataset) @@ -241,7 +249,8 @@ def __init__( dataloc: os.PathLike = None, train: bool = True, label: T.Optional[int] = None, - scale: bool = False + scale: bool = False, + device: torch.device = "cpu" ): rel_path = ( "FashionMNIST/raw/train-images-idx3-ubyte" @@ -256,21 +265,28 @@ def __init__( if scale: dataset = dataset[:, ::3, ::3] dataset = dataset.reshape(dataset.shape[0], -1) - if label is not None: - rel_path = ( + + rel_path = ( "FashionMNIST/raw/train-labels-idx1-ubyte" if train else "FashionMNIST/raw/t10k-labels-idx1-ubyte" ) - path = os.path.join(dataloc, rel_path) - labels = idx2numpy.convert_from_file(path) + path = os.path.join(dataloc, rel_path) + labels = idx2numpy.convert_from_file(path) + + if label is not None: dataset = dataset[labels == label] - super().__init__(dataset, num_bits=8) + labels = labels[labels == label] + + super().__init__(dataset, torch.Tensor(labels), num_bits=8, device=device) def __getitem__(self, index: int): - x = Tensor(self.dataset[index].copy()) + if not isinstance(self.dataset, torch.Tensor): + x = Tensor(self.dataset[index].copy()) + else: + x = self.dataset[index] x = self.transform(x) - return x, 0 + return x, self.labels[index] class FashionMnistSplit(DataSplit): @@ -279,11 +295,13 @@ def __init__( dataloc: os.PathLike = None, val_split: float = 0.1, label: T.Optional[int] = None, + device: torch.device = "cpu" ): + self.label = label if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") self.dataloc = dataloc - self.train = FashionMnistDequantized(self.dataloc, train=True, label=label) + self.train = FashionMnistDequantized(self.dataloc, train=True, label=label, device=device) shuffle = torch.randperm(len(self.train)) self.val = torch.utils.data.Subset( self.train, shuffle[: int(len(self.train) * val_split)] @@ -291,7 +309,7 @@ def __init__( self.train = torch.utils.data.Subset( self.train, shuffle[int(len(self.train) * val_split) :] ) - self.test = FashionMnistDequantized(self.dataloc, train=False, label=label) + self.test = FashionMnistDequantized(self.dataloc, train=False, label=label, device=device) def get_train(self) -> torch.utils.data.Dataset: return self.train @@ -312,7 +330,7 @@ def __init__( digit: T.Optional[int] = None, flatten=True, scale: bool = False, - device: torch.device = None + device: torch.device = "cpu" ): if train: rel_path = "MNIST/raw/train-images-idx3-ubyte" @@ -323,19 +341,24 @@ def __init__( MNIST(dataloc, train=train, download=True) dataset = idx2numpy.convert_from_file(path) + if scale: dataset = dataset[:, ::3, ::3] if flatten: dataset = dataset.reshape(dataset.shape[0], -1) + + if train: + rel_path = "MNIST/raw/train-labels-idx1-ubyte" + else: + rel_path = "MNIST/raw/t10k-labels-idx1-ubyte" + path = os.path.join(dataloc, rel_path) + labels = idx2numpy.convert_from_file(path) + if digit is not None: - if train: - rel_path = "MNIST/raw/train-labels-idx1-ubyte" - else: - rel_path = "MNIST/raw/t10k-labels-idx1-ubyte" - path = os.path.join(dataloc, rel_path) - labels = idx2numpy.convert_from_file(path) dataset = dataset[labels == digit] - super().__init__(torch.Tensor(dataset), num_bits=8, device=device) + labels = labels[labels == digit] + + super().__init__(torch.Tensor(dataset), torch.Tensor(labels), num_bits=8, device=device) def __getitem__(self, index: int): if not isinstance(self.dataset, torch.Tensor): @@ -343,7 +366,7 @@ def __getitem__(self, index: int): else: x = self.dataset[index] x = self.transform(x) - return x, 0 + return x, self.labels[index] class MnistSplit(DataSplit): def __init__( @@ -352,8 +375,9 @@ def __init__( val_split: float = 0.1, digit: T.Optional[int] = None, scale: bool = False, - device: torch.device = None + device: torch.device = "cpu" ): + self.digit = digit if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") self.dataloc = dataloc @@ -393,3 +417,4 @@ def __init__( if not os.path.exists(path): CIFAR10(dataloc, train=train, download=True) +