ShufflingSampler can lead to significantly different free energies compared to default sampler #40

Open
jdrefs opened this issue Mar 23, 2023 · 1 comment

Comments

@jdrefs
Member

jdrefs commented Mar 23, 2023

In distributed execution, the ShufflingSampler potentially samples duplicate data points to ensure synchronized batch processing on each worker. These duplicate data points then contribute twice to the E- and M-steps.

In its current version, the Trainer includes terms associated with duplicate data points when evaluating free energies (e.g., here, here, here, here). This can lead to significantly different results compared to a sequential execution without the ShufflingSampler and hence without duplicate data points (e.g., for an SSSC-House benchmark (\sigma=50, D=144, H=512, |K|=30), I observed a free energy difference on the order of 7).

Furthermore, the duplicated data points introduce additional terms in the Theta updates (in the update_param_batch methods), such that the M-step results differ from those obtained in the sequential execution setting.
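
To make the effect concrete, here is a minimal, self-contained sketch (not tvo code; the per-datapoint values are made up) of how a single duplicated datapoint shifts a free energy that is accumulated as a sum over datapoints:

import torch as to

# hypothetical per-datapoint free energies on one worker
per_point_F = to.tensor([-10.0, -12.0, -11.0])

# the worker re-samples datapoint 0 to stay in sync with the other workers
with_duplicate = to.cat((per_point_F, per_point_F[:1]))

print(per_point_F.sum().item())     # -33.0, the duplicate-free (sequential) result
print(with_duplicate.sum().item())  # -43.0, biased by the duplicated term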

@jdrefs
Member Author

jdrefs commented Mar 23, 2023

As a workaround, we could include an index tensor in the dataset that indicates whether a given datapoint corresponds to a duplicate or not:

# imports inferred from how they are used in this snippet
import numpy as np
import torch as to
import torch.distributed as dist
from torch.distributed import broadcast
from torch.utils.data import DataLoader, TensorDataset

import tvo


class TVODataLoader(DataLoader):
    def __init__(self, data: to.Tensor, **kwargs):
        """TVO DataLoader class. Derived from torch.utils.data.DataLoader.

        :param data: Tensor containing the input dataset. Must have exactly two dimensions (N,D).
        :param kwargs: forwarded to pytorch's DataLoader.

        TVODataLoader is constructed exactly the same way as pytorch's DataLoader,
        but it restricts datasets to a TensorDataset constructed from the data passed
        as a parameter. All other arguments are forwarded to pytorch's DataLoader.

        In the case of distributed execution with unevenly sized datasets per worker,
        TVODataLoader will sample a few datapoints twice to guarantee that each 
        worker iterates over the same number of batches. 

        When iterated over, TVODataLoader yields a tuple containing the indices of
        the datapoints in each batch, the actual datapoints as well as an index
        tensor indicating whether a datapoint corresponds to a duplicate or not.

        TVODataLoader instances optionally expose the attribute `precision`, which is set to the
        dtype of the dataset in data if it is a floating point dtype.
        """
        N = data.shape[0]

        if data.dtype is not to.uint8:
            self.precision = data.dtype

        notduplicate = to.ones(N, dtype=to.bool)
        if tvo.get_run_policy() == "mpi":
            assert dist.is_initialized()
            # Ranks ..., (comm_size-2), (comm_size-1) are assigned one data point more than ranks
            # 0, 1, ... if the dataset cannot be evenly distributed across MPI processes (the split
            # point depends on the total number of data points and number of MPI processes; cf.
            # scatter_to_processes, gather_from_processes).
            # To ensure that all workers can loop over batches in sync, we assign the processes
            # with fewer datapoints one randomly sampled additional datapoint each, and we mark
            # these additional datapoints as duplicates (s.t. models can optionally neglect them).
            n_samples = to.tensor(N)
            comm_size = dist.get_world_size()
            broadcast(n_samples, src=comm_size - 1)
            n_extra_samples = n_samples.item() - N
            if n_extra_samples > 0:
                assert n_extra_samples == 1  # by definition (cf. scatter_to_processes), the amount
                # of datapoints on different MPI ranks should not differ
                # by more than one
                replace = True if n_extra_samples > N else False  # should always be False
                idxs_repeat = np.random.choice(N, size=n_extra_samples, replace=replace)
                data = to.cat((data, data[idxs_repeat]), dim=0)
                notduplicate = to.cat(
                    (notduplicate, to.zeros(n_extra_samples, dtype=to.bool)), dim=0
                )

        dataset = TensorDataset(to.arange(data.shape[0]), data, notduplicate)

        super().__init__(dataset, **kwargs)
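
A hypothetical usage example of the modified loader (names and arguments as in the sketch above; the notduplicate mask is simply carried along with each batch):

# data: (N, D) float tensor local to this worker
loader = TVODataLoader(data, batch_size=32, shuffle=True)
for idx, batch, notduplicate in loader:
    # models that want duplicate-free behaviour can mask the batch,
    # e.g. via idx[notduplicate] and batch[notduplicate]
    pass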

The Trainer could then use the notduplicate index tensor when computing free energies:

# in Trainer._train_epoch
for idx, batch, notduplicate in train_data:
    # ...
    batch_F = model.free_energy(idx[notduplicate], batch[notduplicate], train_states)
    # or alternatively, let the model handle the index tensor internally
    batch_F = model.free_energy(idx, batch, train_states, notduplicate=notduplicate)

One would additionally need to make sure that the Trainer correctly infers the number of datapoints:

# in Trainer.__init__
# ...
notduplicate = train_data.dataset.tensors[2] 
N_train = to.tensor(notduplicate.sum().item())

Similarly, the Trainer could pass notduplicate to model.update_param_batch.
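
A hypothetical sketch of what this could look like on the model side (the signature below is an assumption for illustration, not the current tvo API):

# in a model class; notduplicate is an optional boolean mask over the batch
def update_param_batch(self, idx, batch, states, notduplicate=None):
    if notduplicate is not None:
        idx, batch = idx[notduplicate], batch[notduplicate]
    # ... accumulate the M-step sufficient statistics on the de-duplicated batch ...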
