time, observations = make_periodic_dataset(timepoints=100, extrap=True, max_t=5.0, n=200, noise_weight=0.01)
time.shape, observations.shape
(torch.Size([101]), torch.Size([200, 101, 1]))
Except for the first and last lines, everything else comes from Rubanova’s implementation (comments mine)
make_periodic_dataset (timepoints:int, extrap:bool, max_t:float, n:int, noise_weight:float)

|  | Type | Details |
|---|---|---|
| timepoints | int | Number of time instants |
| extrap | bool | Whether extrapolation is performed |
| max_t | float | Maximum value of the time instants |
| n | int | Number of examples |
| noise_weight | float | Standard deviation of the noise to be added |
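The actual data generation comes from Rubanova's implementation; still, a minimal hypothetical sketch (sine waves with random phases, assuming the extra time instant comes from including the endpoint) shows how the shapes above arise:

```python
import torch

def make_periodic_dataset_sketch(timepoints=100, max_t=5.0, n=200, noise_weight=0.01):
    # hypothetical stand-in for illustration only (not Rubanova's code)
    time = torch.linspace(0.0, max_t, timepoints + 1)  # [timepoints + 1]
    phase = torch.rand(n, 1) * 2 * torch.pi            # one random phase per example
    clean = torch.sin(time[None, :] + phase)           # [n, timepoints + 1]
    noisy = clean + noise_weight * torch.randn_like(clean)
    return time, noisy.unsqueeze(-1)                   # observations: [n, timepoints + 1, 1]

time, observations = make_periodic_dataset_sketch()
time.shape, observations.shape  # (torch.Size([101]), torch.Size([200, 101, 1]))
```

With `timepoints=100` one gets 101 time instants, and each of the `n=200` univariate series carries a trailing feature dimension of 1.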
A class defining a (somewhat complex) collate function for a PyTorch DataLoader
CollateFunction (time:torch.Tensor, n_points_to_subsample=None)

Initialize self. See help(type(self)) for accurate signature.

|  | Type | Default | Details |
|---|---|---|---|
| time | Tensor |  | Time axis [time] |
| n_points_to_subsample | NoneType | None | Number of points to be “subsampled” |
Let us build an object for testing
Collate function expecting time series of length 101, with the second half to be predicted from the first.
We also need a PyTorch DataLoader
dataloader = torch.utils.data.DataLoader(observations, batch_size=10, shuffle=False, collate_fn=collate_fn)
dataloader
<torch.utils.data.dataloader.DataLoader>
How many batches is this DataLoader providing?
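With `n=200` examples and `batch_size=10`, the count is just 200 / 10 = 20. A quick check, using a stand-in tensor in place of the actual observations (and the default collate, since only the batch count matters here):

```python
import torch

# stand-in for the 200 observations of shape [200, 101, 1]
observations = torch.zeros(200, 101, 1)
dataloader = torch.utils.data.DataLoader(observations, batch_size=10, shuffle=False)
print(len(dataloader))  # 200 examples / 10 per batch = 20 batches
```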
Let us get the first batch
Notice that, as seen from `CollateFunction.__call__`'s prototype, the returned type is a dictionary. It contains the following fields
dict_keys(['observed_time', 'observed_data', 'to_predict_at_time', 'to_predict_data', 'observed_mask'])
`observed_time` and `observed_data` make up the first part of a time series we want to learn, whereas `to_predict_at_time` and `to_predict_data` form the second part of the same time series, which we aim to predict; on the other hand, `observed_mask` is `True` for every observation that is available (it only applies to the observed data). If one thinks of this in terms of an input, \(x\), that is given, and a related output, \(y\), that is to be predicted, the latter would be `to_predict_data` and the former would encompass the rest of the fields.
We can check the size of every component
Dimensions of observed_time: (50,)
Dimensions of observed_data: (10, 50, 1)
Dimensions of to_predict_at_time: (51,)
Dimensions of to_predict_data: (10, 51, 1)
Dimensions of observed_mask: (10, 50, 1)
In this simple example, every observation is available
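A minimal, hypothetical re-implementation of such a collate function can make the dictionary's structure concrete. The field names and the half/half split follow the behaviour described above; the actual class is the `CollateFunction` documented earlier:

```python
import torch

class CollateFunctionSketch:
    # hypothetical minimal stand-in for illustration; not the real CollateFunction
    def __init__(self, time):
        half = len(time) // 2
        self.observed_time = time[:half]        # first half of the time axis is observed...
        self.to_predict_at_time = time[half:]   # ...second half is to be predicted

    def __call__(self, batch):
        # a DataLoader passes a list of samples; stack them into [batch, time, features]
        data = torch.stack([torch.as_tensor(b) for b in batch])
        half = len(self.observed_time)
        observed = data[:, :half]
        return {
            'observed_time': self.observed_time,
            'observed_data': observed,
            'to_predict_at_time': self.to_predict_at_time,
            'to_predict_data': data[:, half:],
            # every observation is available in this simple setting
            'observed_mask': torch.ones_like(observed, dtype=torch.bool),
        }
```

With a batch of 10 series of length 101, this sketch reproduces the dimensions listed above (50 observed instants, 51 to be predicted), and `observed_mask` is `True` everywhere.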
If one wants to move this object to another device, this method will move all the relevant internal state.
CollateFunction.to (device)