__all__ = ["AutoRegressiveBaseModel"]

from loguru import logger
from typing import Any, Callable, List, Tuple, Union

import torch
from torch import Tensor

from pytorch_forecasting.metrics import MultiLoss, DistributionLoss
from pytorch_forecasting.utils import to_list, apply_to_list
from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel as AutoRegressiveBaseModel_


class AutoRegressiveBaseModel(AutoRegressiveBaseModel_):  # pylint: disable=abstract-method
    """``AutoRegressiveBaseModel`` from ``pytorch_forecasting``, fixed for multi-target prediction. Tested with ``LSTM``."""

    def output_to_prediction(
        self,
        normalized_prediction_parameters: torch.Tensor,
        target_scale: Union[List[torch.Tensor], torch.Tensor],
        n_samples: int = 1,
        **kwargs: Any,
    ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]:
        """
        Convert network output to rescaled and normalized prediction.

        Function is typically not called directly but via :py:meth:`~decode_autoregressive`.

        Args:
            normalized_prediction_parameters (torch.Tensor): network prediction output
            target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale to rescale network output
            n_samples (int, optional): number of samples to draw independently. Defaults to 1.
            **kwargs: extra arguments for dictionary passed to :py:meth:`~transform_output` method.

        Returns:
            Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]: tuple of rescaled prediction and
                normalized prediction (e.g. for input into next auto-regressive step)
        """
        logger.trace(f"normalized_prediction_parameters={normalized_prediction_parameters.size()}")
        B = normalized_prediction_parameters.size(0)
        D = normalized_prediction_parameters.size(-1)
        single_prediction = to_list(normalized_prediction_parameters)[0].ndim == 2
        logger.trace(f"single_prediction={single_prediction}")
        if single_prediction:  # add time dimension as it is expected
            normalized_prediction_parameters = apply_to_list(
                normalized_prediction_parameters, lambda x: x.unsqueeze(1)
            )
        # transform into real space
        prediction_parameters = self.transform_output(
            prediction=normalized_prediction_parameters, target_scale=target_scale, **kwargs
        )
        logger.trace(
            f"prediction_parameters ({len(prediction_parameters)}): {[p.size() for p in prediction_parameters]}"
        )
        # sample value(s) from distribution and select first sample
        if isinstance(self.loss, DistributionLoss) or (
            isinstance(self.loss, MultiLoss) and isinstance(self.loss[0], DistributionLoss)
        ):
            if n_samples > 1:
                prediction_parameters = apply_to_list(
                    prediction_parameters, lambda x: x.reshape(int(x.size(0) / n_samples), n_samples, -1)
                )
                prediction = self.loss.sample(prediction_parameters, 1)
                prediction = apply_to_list(prediction, lambda x: x.reshape(x.size(0) * n_samples, 1, -1))
            else:
                prediction = self.loss.sample(normalized_prediction_parameters, 1)
        else:
            prediction = prediction_parameters
        logger.trace(f"prediction ({len(prediction)}): {[p.size() for p in prediction]}")
        # normalize prediction
        normalized_prediction = self.output_transformer.transform(prediction, target_scale=target_scale)
        if isinstance(normalized_prediction, list):
            logger.trace(f"normalized_prediction: {[p.size() for p in normalized_prediction]}")
            # deviation from upstream, which concatenates: torch.cat(normalized_prediction, dim=-1)
            input_target = normalized_prediction[-1]
        else:
            logger.trace(f"normalized_prediction: {normalized_prediction.size()}")
            input_target = normalized_prediction  # set next input target to normalized prediction
        logger.trace(f"input_target: {input_target.size()}")
        assert input_target.size(0) == B, f"{input_target.size()} but B={B}"
        assert input_target.size(-1) == D, f"{input_target.size()} but D={D}"
        # remove time dimension
        if single_prediction:
            prediction = apply_to_list(prediction, lambda x: x.squeeze(1))
            input_target = input_target.squeeze(1)
        logger.trace(f"input_target: {input_target.size()}")
        return prediction, input_target

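    # Illustrative shape walkthrough for ``output_to_prediction`` (assumed
    # common case: non-distribution loss, single decoding step; ``B`` = batch
    # size, ``D`` = number of prediction parameters per step):
    #   in : normalized_prediction_parameters      (B, D)
    #   ->  unsqueeze(1) adds the time dimension   (B, 1, D)
    #   ->  transform_output rescales              (B, 1, D)
    #   ->  prediction / input_target, squeezed    (B, D)
    #   out: (prediction, input_target)
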
    def decode_autoregressive(
        self,
        decode_one: Callable,
        first_target: Union[List[torch.Tensor], torch.Tensor],
        first_hidden_state: Any,
        target_scale: Union[List[torch.Tensor], torch.Tensor],
        n_decoder_steps: int,
        n_samples: int = 1,
        **kwargs: Any,
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        """
        Make predictions in an auto-regressive manner. Supports only continuous targets.

        Args:
            decode_one (Callable): function that takes at least the following arguments:

                * ``idx`` (int): index of decoding step (from 0 to n_decoder_steps-1)
                * ``lagged_targets`` (List[torch.Tensor]): list of normalized targets.
                    List is ``idx + 1`` elements long with the most recent entry at the end,
                    i.e. ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
                * ``hidden_state`` (Any): current hidden state required for prediction.
                    Keys are variable names. Only lags that are greater than ``idx`` are included.
                * additional arguments are not dynamic but can be passed via the ``**kwargs`` argument.

                It returns a tuple of the (not rescaled) network prediction output and the hidden state
                for the next auto-regressive step (see the illustrative ``SketchLSTM`` at the end of
                this module).
            first_target (Union[List[torch.Tensor], torch.Tensor]): first target value to use for decoding
            first_hidden_state (Any): first hidden state used for decoding
            target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale as in ``x``
            n_decoder_steps (int): number of decoding/prediction steps
            n_samples (int): number of independent samples to draw from the distribution -
                only relevant for multivariate models. Defaults to 1.
            **kwargs: additional arguments that are passed to the decode_one function.

        Returns:
            Union[List[torch.Tensor], torch.Tensor]: re-scaled prediction
        """
        # make predictions which are fed into next step
        output: List[Union[List[Tensor], Tensor]] = []
        current_hidden_state = first_hidden_state
        normalized_output = [first_target]
        for idx in range(n_decoder_steps):
            # decode one step from the lagged targets and the current hidden state
            current_target, current_hidden_state = decode_one(
                idx, lagged_targets=normalized_output, hidden_state=current_hidden_state, **kwargs
            )
            assert isinstance(current_target, Tensor)
            logger.trace(f"current_target: {current_target.size()}")
            # get prediction and its normalized version for the next step
            prediction, current_target = self.output_to_prediction(
                current_target, target_scale=target_scale, n_samples=n_samples
            )
            logger.trace(f"current_target: {current_target.size()}")
            if isinstance(prediction, Tensor):
                logger.trace(f"prediction ({type(prediction)}): {prediction.size()}")
            else:
                logger.trace(
                    f"prediction ({type(prediction)}|{len(prediction)}): {[p.size() for p in prediction]}"
                )
            # save normalized output for lagged targets
            normalized_output.append(current_target)
            # set output to unnormalized samples, append each target as n_batch_samples x n_random_samples
            output.append(prediction)
            # check shapes before finishing
            if isinstance(prediction, Tensor):
                logger.trace(f"output ({len(output)}): {[o.size() for o in output]}")  # type: ignore
            else:
                logger.trace(f"output ({len(output)}): {[len(o) for o in output]}")
        if isinstance(self.hparams.target, str):
            # single target: output is List[Tensor]
            final_output = torch.stack(output, dim=1)  # type: ignore
            logger.trace(f"final_output: {final_output.size()}")
            return final_output
        # multi-target: output is List[List[Tensor]]. Upstream would stack once per target position:
        #     [torch.stack([out[idx] for out in output], dim=1) for idx in range(len(self.target_positions))]
        # but self.target_positions is always Tensor([0]) here, so its len() is always 1.
        final_output_multitarget = torch.stack([out[0] for out in output], dim=1)
        if final_output_multitarget.dim() > 3:
            final_output_multitarget = final_output_multitarget.squeeze(2)
        # torch.stack always returns a Tensor, so no list handling is needed here
        logger.trace(f"final_output_multitarget: {final_output_multitarget.size()}")
        # split the last dimension into one tensor per target
        return [final_output_multitarget[..., i] for i in range(final_output_multitarget.size(-1))]
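
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): how a hypothetical subclass
# could route its decoder through ``decode_autoregressive``. ``SketchLSTM``,
# ``self.rnn``, ``self.output_layer`` and the keys of ``x`` are assumptions
# for illustration only.
# ---------------------------------------------------------------------------
class SketchLSTM(AutoRegressiveBaseModel):  # pylint: disable=abstract-method
    """Hypothetical subclass illustrating the ``decode_one`` contract."""

    def decode(self, x: Any, hidden_state: Any) -> Union[List[Tensor], Tensor]:
        def decode_one(idx: int, lagged_targets: List[Tensor], hidden_state: Any) -> Tuple[Tensor, Any]:
            # feed the most recent normalized prediction back into the RNN (assumed ``self.rnn``)
            decoder_input = lagged_targets[-1].unsqueeze(1)  # (B, 1, D)
            output, hidden_state = self.rnn(decoder_input, hidden_state)
            # project to (not rescaled) prediction parameters of shape (B, D)
            return self.output_layer(output.squeeze(1)), hidden_state

        return self.decode_autoregressive(
            decode_one,
            first_target=x["encoder_target"][:, -1],
            first_hidden_state=hidden_state,
            target_scale=x["target_scale"],
            n_decoder_steps=int(x["decoder_lengths"].max()),
        )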