diff --git a/LICENSE b/LICENSE index 9b176e8eea60b..faaa6cff15b14 100644 --- a/LICENSE +++ b/LICENSE @@ -339,4 +339,14 @@ LMax Disruptor is open source software licensed under the Apache License 2.0 and Project page: https://github.com/LMAX-Exchange/disruptor License: https://github.com/LMAX-Exchange/disruptor/blob/master/LICENCE.txt +-------------------------------------------------------------------------------- + +The following files include code modified from chronos-forecasting project. + +./iotdb-core/ainode/iotdb/ainode/core/model/chronos2/* + +The chronos-forecasting is open source software licensed under the Apache License 2.0 +Project page: https://github.com/amazon-science/chronos-forecasting +License: https://github.com/amazon-science/chronos-forecasting/blob/main/LICENSE + -------------------------------------------------------------------------------- \ No newline at end of file diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 1d21a4d90f017..35fb51598b71b 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -49,7 +49,9 @@ public class AINodeTestUtils { new AbstractMap.SimpleEntry<>( "sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active"))) + "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/__init__.py new file mode 100644 index 0000000000000..2a1e720805f29 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/base.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/base.py new file mode 100644 index 0000000000000..17052145857b8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/base.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +if TYPE_CHECKING: + import pandas as pd + from transformers import PreTrainedModel + + +from iotdb.ainode.core.model.chronos2.utils import left_pad_and_stack_1D + + +class ForecastType(Enum): + SAMPLES = "samples" + QUANTILES = "quantiles" + + +class PipelineRegistry(type): + REGISTRY: Dict[str, "PipelineRegistry"] = {} + + def __new__(cls, name, bases, attrs): + """See, https://github.com/faif/python-patterns.""" + new_cls = type.__new__(cls, name, bases, attrs) + if name is not None: + cls.REGISTRY[name] = new_cls + + return new_cls + + +class BaseChronosPipeline(metaclass=PipelineRegistry): + forecast_type: ForecastType + dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32} + + def __init__(self, inner_model: "PreTrainedModel"): + """ + Parameters + ---------- + inner_model : PreTrainedModel + A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration + """ + # for easy access to the inner HF-style model + self.inner_model = inner_model + + @property + def model_context_length(self) -> int: + raise NotImplementedError() + + @property + def model_prediction_length(self) -> int: + raise NotImplementedError() + + def _prepare_and_validate_context( + self, context: Union[torch.Tensor, List[torch.Tensor]] + ): + if isinstance(context, list): + context = left_pad_and_stack_1D(context) + assert isinstance(context, torch.Tensor) + if context.ndim == 1: + context = context.unsqueeze(0) + assert context.ndim == 2 + + return context + + def predict( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + prediction_length: Optional[int] = None, + ): + """ + Get forecasts for the given time series. Predictions will be + returned in fp32 on the cpu. + + Parameters + ---------- + inputs + Input series. This is either a 1D tensor, or a list + of 1D tensors, or a 2D tensor whose first dimension + is batch. In the latter case, use left-padding with + ``torch.nan`` to align series of different lengths. + prediction_length + Time steps to predict. Defaults to a model-dependent + value if not given. + + Returns + ------- + forecasts + Tensor containing forecasts. The layout and meaning + of the forecasts values depends on ``self.forecast_type``. + """ + raise NotImplementedError() + + def predict_quantiles( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + prediction_length: Optional[int] = None, + quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get quantile and mean forecasts for given time series. + Predictions will be returned in fp32 on the cpu. + + Parameters + ---------- + inputs : Union[torch.Tensor, List[torch.Tensor]] + Input series. This is either a 1D tensor, or a list + of 1D tensors, or a 2D tensor whose first dimension + is batch. 
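# A minimal sketch (not part of the diff) of the input layout that predict()
# and predict_quantiles() describe above: series of different lengths are
# left-padded with NaN so they can be stacked into one (batch, length)
# tensor. `left_pad_and_stack_1D` from chronos2.utils is presumed to do
# essentially this.
import torch

series = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]
max_len = max(s.shape[-1] for s in series)
batch = torch.stack(
    [torch.cat([torch.full((max_len - s.shape[-1],), torch.nan), s]) for s in series]
)
# batch[1] is tensor([nan, 4., 5.]) -- the shorter series is padded on the left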
In the latter case, use left-padding with + ``torch.nan`` to align series of different lengths. + prediction_length : Optional[int], optional + Time steps to predict. Defaults to a model-dependent + value if not given. + quantile_levels : List[float], optional + Quantile levels to compute, by default [0.1, 0.2, ..., 0.9] + + Returns + ------- + quantiles + Tensor containing quantile forecasts. Shape + (batch_size, prediction_length, num_quantiles) + mean + Tensor containing mean (point) forecasts. Shape + (batch_size, prediction_length) + """ + raise NotImplementedError() + + def predict_df( + self, + df: "pd.DataFrame", + *, + id_column: str = "item_id", + timestamp_column: str = "timestamp", + target: str = "target", + prediction_length: int | None = None, + quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + validate_inputs: bool = True, + **predict_kwargs, + ) -> "pd.DataFrame": + """ + Perform forecasting on time series data in a long-format pandas DataFrame. + + Parameters + ---------- + df + Time series data in long format with an id column, a timestamp, and one target column. + Any other columns, if present, will be ignored + id_column + The name of the column which contains the unique time series identifiers, by default "item_id" + timestamp_column + The name of the column which contains timestamps, by default "timestamp" + All time series in the dataframe must have regular timestamps with the same frequency (no gaps) + target + The name of the column which contains the target variables to be forecasted, by default "target" + prediction_length + Number of steps to predict for each time series + quantile_levels + Quantile levels to compute + validate_inputs + When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a + regular frequency, and item IDs match between past and future data. Setting to False disables these checks. + **predict_kwargs + Additional arguments passed to predict_quantiles + + Returns + ------- + The forecasts dataframe generated by the model with the following columns + - `id_column`: The time series ID + - `timestamp_column`: Future timestamps + - "target_name": The name of the target column + - "predictions": The point predictions generated by the model + - One column for predictions at each quantile level in `quantile_levels` + """ + try: + import pandas as pd + + from .df_utils import convert_df_input_to_list_of_dicts_input + except ImportError: + raise ImportError( + "pandas is required for predict_df. Please install it with `pip install pandas`." + ) + + if not isinstance(target, str): + raise ValueError( + f"Expected `target` to be str, but found {type(target)}. {self.__class__.__name__} only supports univariate forecasting." 
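# Illustrative only: a tiny long-format frame of the shape predict_df expects
# (default column names item_id / timestamp / target; all values made up).
import pandas as pd

history = pd.DataFrame(
    {
        "item_id": ["sensor_a"] * 4 + ["sensor_b"] * 4,
        "timestamp": list(pd.date_range("2024-01-01", periods=4, freq="h")) * 2,
        "target": [1.0, 2.0, 3.0, 4.0, 10.0, 11.0, 12.0, 13.0],
    }
)
# pipeline.predict_df(history, prediction_length=2) would then return one row
# per future timestamp per series, with a "predictions" column plus one
# column per requested quantile level ("0.1", ..., "0.9").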
+ ) + + if prediction_length is None: + prediction_length = self.model_prediction_length + + inputs, original_order, prediction_timestamps = ( + convert_df_input_to_list_of_dicts_input( + df=df, + future_df=None, + id_column=id_column, + timestamp_column=timestamp_column, + target_columns=[target], + prediction_length=prediction_length, + validate_inputs=validate_inputs, + ) + ) + + # NOTE: any covariates, if present, are ignored here + context = [ + torch.tensor(item["target"]).squeeze(0) for item in inputs + ] # squeeze the extra variate dim + + # Generate forecasts + quantiles, mean = self.predict_quantiles( + inputs=context, + prediction_length=prediction_length, + quantile_levels=quantile_levels, + limit_prediction_length=False, + **predict_kwargs, + ) + + quantiles_np = quantiles.numpy() # [n_series, horizon, num_quantiles] + mean_np = mean.numpy() # [n_series, horizon] + + results_dfs = [] + for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()): + q_pred = quantiles_np[i] # (horizon, num_quantiles) + point_pred = mean_np[i] # (horizon) + + series_forecast_data = { + id_column: series_id, + timestamp_column: future_ts, + "target_name": target, + } + series_forecast_data["predictions"] = point_pred + for q_idx, q_level in enumerate(quantile_levels): + series_forecast_data[str(q_level)] = q_pred[:, q_idx] + + results_dfs.append(pd.DataFrame(series_forecast_data)) + + predictions_df = pd.concat(results_dfs, ignore_index=True) + predictions_df.set_index(id_column, inplace=True) + predictions_df = predictions_df.loc[original_order] + predictions_df.reset_index(inplace=True) + + return predictions_df + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, Path], + *model_args, + force_s3_download=False, + **kwargs, + ): + """ + Load the model, either from a local path, S3 prefix, or from the HuggingFace Hub. + Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``. + """ + + from transformers import AutoConfig + + torch_dtype = kwargs.get("torch_dtype", "auto") + if torch_dtype != "auto" and isinstance(torch_dtype, str): + kwargs["torch_dtype"] = cls.dtypes[torch_dtype] + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr( + config, "chronos_config" + ) + + if not is_valid_config: + raise ValueError("Not a Chronos config file") + + pipeline_class_name = getattr( + config, "chronos_pipeline_class", "ChronosPipeline" + ) + class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name) + if class_ is None: + raise ValueError( + f"Trying to load unknown pipeline class: {pipeline_class_name}" + ) + + return class_.from_pretrained( # type: ignore[attr-defined] + pretrained_model_name_or_path, *model_args, **kwargs + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py new file mode 100644 index 0000000000000..8b221f5f149d7 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py @@ -0,0 +1,703 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import copy +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import AutoConfig +from transformers.models.t5.modeling_t5 import ( + ACT2FN, + T5Config, + T5LayerNorm, + T5PreTrainedModel, + T5Stack, +) +from transformers.utils import ModelOutput + +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.chronos2.base import BaseChronosPipeline, ForecastType + +logger = Logger() + + +@dataclass +class ChronosBoltConfig: + context_length: int + prediction_length: int + input_patch_size: int + input_patch_stride: int + quantiles: List[float] + use_reg_token: bool = False + + +@dataclass +class ChronosBoltOutput(ModelOutput): + loss: Optional[torch.Tensor] = None + quantile_preds: Optional[torch.Tensor] = None + attentions: Optional[torch.Tensor] = None + cross_attentions: Optional[torch.Tensor] = None + + +class Patch(nn.Module): + def __init__(self, patch_size: int, patch_stride: int) -> None: + super().__init__() + self.patch_size = patch_size + self.patch_stride = patch_stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + length = x.shape[-1] + + if length % self.patch_size != 0: + padding_size = ( + *x.shape[:-1], + self.patch_size - (length % self.patch_size), + ) + padding = torch.full( + size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device + ) + x = torch.concat((padding, x), dim=-1) + + x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride) + return x + + +class InstanceNorm(nn.Module): + """ + Apply standardization along the last dimension and optionally apply arcsinh after standardization. 
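# Minimal sketch of what Patch.forward does: left-pad with NaN up to a
# multiple of the patch size, then unfold the time axis into
# (batch, num_patches, patch_size) windows.
import torch

x = torch.arange(1.0, 8.0).unsqueeze(0)              # shape (1, 7)
patch_size = patch_stride = 4
pad = patch_size - x.shape[-1] % patch_size          # one missing value
x = torch.cat([torch.full((1, pad), torch.nan), x], dim=-1)
patches = x.unfold(dimension=-1, size=patch_size, step=patch_stride)
# patches.shape == (1, 2, 4); the first patch starts with the NaN padding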
+ """ + + def __init__(self, eps: float = 1e-5, use_arcsinh: bool = False) -> None: + super().__init__() + self.eps = eps + self.use_arcsinh = use_arcsinh + + def forward( + self, + x: torch.Tensor, + loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if loc_scale is None: + loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0) + scale = torch.nan_to_num( + (x - loc).square().nanmean(dim=-1, keepdim=True).sqrt(), nan=1.0 + ) + scale = torch.where(scale == 0, self.eps, scale) + else: + loc, scale = loc_scale + + scaled_x = (x - loc) / scale + + if self.use_arcsinh: + scaled_x = torch.arcsinh(scaled_x) + + return scaled_x.to(orig_dtype), (loc, scale) + + def inverse( + self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + loc, scale = loc_scale + + if self.use_arcsinh: + x = torch.sinh(x) + + x = x * scale + loc + + return x.to(orig_dtype) + + +class ResidualBlock(nn.Module): + def __init__( + self, + in_dim: int, + h_dim: int, + out_dim: int, + act_fn_name: str, + dropout_p: float = 0.0, + use_layer_norm: bool = False, + ) -> None: + super().__init__() + + self.dropout = nn.Dropout(dropout_p) + self.hidden_layer = nn.Linear(in_dim, h_dim) + self.act = ACT2FN[act_fn_name] + self.output_layer = nn.Linear(h_dim, out_dim) + self.residual_layer = nn.Linear(in_dim, out_dim) + + self.use_layer_norm = use_layer_norm + if use_layer_norm: + self.layer_norm = T5LayerNorm(out_dim) + + def forward(self, x: torch.Tensor): + hid = self.act(self.hidden_layer(x)) + out = self.dropout(self.output_layer(hid)) + res = self.residual_layer(x) + + out = out + res + + if self.use_layer_norm: + return self.layer_norm(out) + return out + + +class ChronosBoltModelForForecasting(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ # type: ignore + r"input_patch_embedding\.", + r"output_patch_embedding\.", + ] + _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore + + def __init__(self, config: T5Config): + assert hasattr(config, "chronos_config"), "Not a Chronos config file" + + super().__init__(config) + self.model_dim = config.d_model + + self.chronos_config = ChronosBoltConfig(**config.chronos_config) + + # Only decoder_start_id (and optionally REG token) + if self.chronos_config.use_reg_token: + config.reg_token_id = 1 + + config.vocab_size = 2 if self.chronos_config.use_reg_token else 1 + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + # Input patch embedding layer + self.input_patch_embedding = ResidualBlock( + in_dim=self.chronos_config.input_patch_size * 2, + h_dim=config.d_ff, + out_dim=config.d_model, + act_fn_name=config.dense_act_fn, + dropout_p=config.dropout_rate, + ) + + # patching layer + self.patch = Patch( + patch_size=self.chronos_config.input_patch_size, + patch_stride=self.chronos_config.input_patch_stride, + ) + + # instance normalization, also referred to as "scaling" in Chronos and GluonTS + self.instance_norm = InstanceNorm() + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + self.encoder = T5Stack(encoder_config, self.shared) + + self._init_decoder(config) + + self.num_quantiles = len(self.chronos_config.quantiles) + quantiles = 
torch.tensor(self.chronos_config.quantiles, dtype=self.dtype) + self.quantiles: torch.Tensor + self.register_buffer("quantiles", quantiles, persistent=False) + + self.output_patch_embedding = ResidualBlock( + in_dim=config.d_model, + h_dim=config.d_ff, + out_dim=self.num_quantiles * self.chronos_config.prediction_length, + act_fn_name=config.dense_act_fn, + dropout_p=config.dropout_rate, + ) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def _init_weights(self, module): + super()._init_weights(module) + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, (self.__class__)): + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, ResidualBlock): + module.hidden_layer.weight.data.normal_( + mean=0.0, + std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5), + ) + if ( + hasattr(module.hidden_layer, "bias") + and module.hidden_layer.bias is not None + ): + module.hidden_layer.bias.data.zero_() + + module.residual_layer.weight.data.normal_( + mean=0.0, + std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5), + ) + if ( + hasattr(module.residual_layer, "bias") + and module.residual_layer.bias is not None + ): + module.residual_layer.bias.data.zero_() + + module.output_layer.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if ( + hasattr(module.output_layer, "bias") + and module.output_layer.bias is not None + ): + module.output_layer.bias.data.zero_() + + def encode( + self, context: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> Tuple[ + torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor + ]: + mask = ( + mask.to(context.dtype) + if mask is not None + else torch.isnan(context).logical_not().to(context.dtype) + ) + + batch_size, _ = context.shape + if context.shape[-1] > self.chronos_config.context_length: + context = context[..., -self.chronos_config.context_length :] + mask = mask[..., -self.chronos_config.context_length :] + + # scaling + context, loc_scale = self.instance_norm(context) + + # the scaling op above is done in 32-bit precision, + # then the context is moved to model's dtype + context = context.to(self.dtype) + mask = mask.to(self.dtype) + + # patching + patched_context = self.patch(context) + patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0) + patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0) + # concat context and mask along patch dim + patched_context = torch.cat([patched_context, patched_mask], dim=-1) + + # attention_mask = 1 if at least one item in the patch is observed + attention_mask = ( + patched_mask.sum(dim=-1) > 0 + ) # (batch_size, patched_seq_length) + + input_embeds = self.input_patch_embedding(patched_context) + + if self.chronos_config.use_reg_token: + # Append [REG] + reg_input_ids = torch.full( + (batch_size, 1), + self.config.reg_token_id, + device=input_embeds.device, + ) + reg_embeds = self.shared(reg_input_ids) + input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2) + attention_mask = torch.cat( + [ + attention_mask.to(self.dtype), + torch.ones_like(reg_input_ids).to(self.dtype), + ], + dim=-1, + ) + + encoder_outputs = self.encoder( + attention_mask=attention_mask, + inputs_embeds=input_embeds, + ) + + return encoder_outputs[0], loc_scale, input_embeds, attention_mask + + def forward( + self, + context: torch.Tensor, + mask: Optional[torch.Tensor] = 
None, + target: Optional[torch.Tensor] = None, + target_mask: Optional[torch.Tensor] = None, + ) -> ChronosBoltOutput: + batch_size = context.size(0) + + hidden_states, loc_scale, input_embeds, attention_mask = self.encode( + context=context, mask=mask + ) + sequence_output = self.decode(input_embeds, attention_mask, hidden_states) + + quantile_preds_shape = ( + batch_size, + self.num_quantiles, + self.chronos_config.prediction_length, + ) + quantile_preds = self.output_patch_embedding(sequence_output).view( + *quantile_preds_shape + ) + + loss = None + if target is not None: + # normalize target + target, _ = self.instance_norm(target, loc_scale) + target = target.unsqueeze(1) # type: ignore + assert self.chronos_config.prediction_length >= target.shape[-1] + + target = target.to(quantile_preds.device) + target_mask = ( + target_mask.unsqueeze(1).to(quantile_preds.device) + if target_mask is not None + else ~torch.isnan(target) + ) + target[~target_mask] = 0.0 + + # pad target and target_mask if they are shorter than model's prediction_length + if self.chronos_config.prediction_length > target.shape[-1]: + padding_shape = ( + *target.shape[:-1], + self.chronos_config.prediction_length - target.shape[-1], + ) + target = torch.cat( + [target, torch.zeros(padding_shape).to(target)], dim=-1 + ) + target_mask = torch.cat( + [target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1 + ) + + loss = ( + 2 + * torch.abs( + (target - quantile_preds) + * ( + (target <= quantile_preds).float() + - self.quantiles.view(1, self.num_quantiles, 1) + ) + ) + * target_mask.float() + ) + loss = loss.mean(dim=-2) # Mean over prediction horizon + loss = loss.sum(dim=-1) # Sum over quantile levels + loss = loss.mean() # Mean over batch + + # Unscale predictions + quantile_preds = self.instance_norm.inverse( + quantile_preds.view(batch_size, -1), + loc_scale, + ).view(*quantile_preds_shape) + + return ChronosBoltOutput( + loss=loss, + quantile_preds=quantile_preds, + ) + + def _init_decoder(self, config): + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + def decode( + self, + input_embeds, + attention_mask, + hidden_states, + output_attentions=False, + ): + """ + Parameters + ---------- + input_embeds: torch.Tensor + Patched and embedded inputs. Shape (batch_size, patched_context_length, d_model) + attention_mask: torch.Tensor + Attention mask for the patched context. Shape (batch_size, patched_context_length), type: torch.int64 + hidden_states: torch.Tensor + Hidden states returned by the encoder. 
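# Toy check of the quantile (pinball) loss computed in forward() above, for a
# single target value and one prediction per quantile level:
# loss_q = 2 * |target - pred_q| * |1{target <= pred_q} - q|.
import torch

quantiles = torch.tensor([0.1, 0.5, 0.9])
target = torch.tensor(3.0)
preds = torch.tensor([2.0, 3.5, 5.0])
loss = 2 * torch.abs((target - preds) * ((target <= preds).float() - quantiles))
# loss == tensor([0.2000, 0.5000, 0.4000]); averaging over the horizon,
# summing over levels and averaging over the batch, as in forward(), gives
# the training objective.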
Shape (batch_size, patched_context_length, d_model) + + Returns + ------- + last_hidden_state + Last hidden state returned by the decoder, of shape (batch_size, 1, d_model) + """ + batch_size = input_embeds.shape[0] + decoder_input_ids = torch.full( + (batch_size, 1), + self.config.decoder_start_token_id, + device=input_embeds.device, + ) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + return_dict=True, + ) + + return decoder_outputs.last_hidden_state # sequence_outputs, b x 1 x d_model + + +class ChronosBoltPipeline(BaseChronosPipeline): + forecast_type: ForecastType = ForecastType.QUANTILES + default_context_length: int = 2048 + + def __init__(self, model: ChronosBoltModelForForecasting): + super().__init__(inner_model=model) # type: ignore + self.model = model + + @property + def model_context_length(self) -> int: + return self.model.chronos_config.context_length + + @property + def model_prediction_length(self) -> int: + return self.model.chronos_config.prediction_length + + @property + def quantiles(self) -> List[float]: + return self.model.config.chronos_config["quantiles"] + + @torch.no_grad() + def embed( + self, context: Union[torch.Tensor, List[torch.Tensor]] + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Get encoder embeddings for the given time series. + + Parameters + ---------- + context + Input series. This is either a 1D tensor, or a list + of 1D tensors, or a 2D tensor whose first dimension + is batch. In the latter case, use left-padding with + ``torch.nan`` to align series of different lengths. + + Returns + ------- + embeddings, loc_scale + A tuple of two items: the encoder embeddings and the loc_scale, + i.e., the mean and std of the original time series. + The encoder embeddings are shaped (batch_size, num_patches + 1, d_model), + where num_patches is the number of patches in the time series + and the extra 1 is for the [REG] token (if used by the model). + """ + context_tensor = self._prepare_and_validate_context(context=context) + model_context_length = self.model.config.chronos_config["context_length"] + + if context_tensor.shape[-1] > model_context_length: + context_tensor = context_tensor[..., -model_context_length:] + + context_tensor = context_tensor.to( + device=self.model.device, + dtype=torch.float32, + ) + embeddings, loc_scale, *_ = self.model.encode(context=context_tensor) + return embeddings.cpu(), ( + loc_scale[0].squeeze(-1).cpu(), + loc_scale[1].squeeze(-1).cpu(), + ) + + def predict( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + prediction_length: Optional[int] = None, + limit_prediction_length: bool = False, + ) -> torch.Tensor: + """ + Get forecasts for the given time series. + + Refer to the base method (``BaseChronosPipeline.predict``) + for details on shared parameters. + Additional parameters + --------------------- + limit_prediction_length + Force prediction length smaller or equal than the + built-in prediction length from the model. False by + default. When true, fail loudly if longer predictions + are requested, otherwise longer predictions are allowed. + + Returns + ------- + torch.Tensor + Forecasts of shape (batch_size, num_quantiles, prediction_length) + where num_quantiles is the number of quantiles the model has been + trained to output. For official Chronos-Bolt models, the value of + num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles. 
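# Hedged usage sketch (the checkpoint path is hypothetical; any local
# directory with a Chronos-Bolt style config and weights should work):
import torch

from iotdb.ainode.core.model.chronos2.chronos_bolt import ChronosBoltPipeline

pipeline = ChronosBoltPipeline.from_pretrained("/models/chronos2-checkpoint")
quantiles, mean = pipeline.predict_quantiles(
    torch.rand(4, 512),          # 4 series, 512 past values each
    prediction_length=64,
)
# For a model trained on nine quantile levels, quantiles has shape
# (4, 64, 9) and mean has shape (4, 64).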
+ + Raises + ------ + ValueError + When limit_prediction_length is True and the prediction_length is + greater than model's training prediction_length. + """ + context_tensor = self._prepare_and_validate_context(context=inputs) + + if prediction_length is None: + prediction_length = self.model_prediction_length + + if prediction_length > self.model_prediction_length: + msg = ( + f"We recommend keeping prediction length <= {self.model_prediction_length}. " + "The quality of longer predictions may degrade since the model is not optimized for it. " + ) + if limit_prediction_length: + msg += "You can turn off this check by setting `limit_prediction_length=False`." + raise ValueError(msg) + logger.warning(msg) + + predictions = [] + remaining = prediction_length + + # We truncate the context here because otherwise batches with very long + # context could take up large amounts of GPU memory unnecessarily. + if context_tensor.shape[-1] > self.model_context_length: + context_tensor = context_tensor[..., -self.model_context_length :] + + context_tensor = context_tensor.to( + device=self.model.device, dtype=torch.float32 + ) + # First block prediction + with torch.no_grad(): + prediction: torch.Tensor = self.model( + context=context_tensor + ).quantile_preds.to(context_tensor) + + predictions.append(prediction) + remaining -= prediction.shape[-1] + + # NOTE: The following heuristic for better prediction intervals with long-horizon forecasts + # uses all quantiles generated by the model for the first `model_prediction_length` steps, + # concatenating each quantile with the context and generating the next `model_prediction_length` steps. + # The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles` + # by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles` + # when the `prediction_length` is greater than `model_prediction_length`. 
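# Toy illustration of the reduction step in the long-horizon heuristic
# described above: after feeding each of the nine quantile trajectories back
# as context, the 9 x 9 resulting "samples" are collapsed back to nine
# levels with empirical quantiles.
import torch

batch, n_q, horizon = 2, 9, 64
samples = torch.randn(batch, n_q * n_q, horizon)      # stand-in for model output
levels = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
reduced = torch.quantile(samples, q=levels, dim=1).transpose(0, 1)
# reduced.shape == (2, 9, 64), i.e. (batch, n_quantiles, model_prediction_length)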
+ + if remaining > 0: + # Expand the context along quantile axis + context_tensor = context_tensor.unsqueeze(1).repeat( + 1, len(self.quantiles), 1 + ) + + quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device) + while remaining > 0: + # Append the prediction to context + context_tensor = torch.cat([context_tensor, prediction], dim=-1)[ + ..., -self.model_context_length : + ] + (batch_size, n_quantiles, context_length) = context_tensor.shape + + with torch.no_grad(): + # Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length) + prediction = self.model( + context=context_tensor.reshape( + batch_size * n_quantiles, context_length + ) + ).quantile_preds.to(context_tensor) + # Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length) + prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1) + # Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length) + prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose( + 0, 1 + ) + + predictions.append(prediction) + remaining -= prediction.shape[-1] + + return torch.cat(predictions, dim=-1)[..., :prediction_length].to( + dtype=torch.float32, device="cpu" + ) + + def predict_quantiles( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + prediction_length: Optional[int] = None, + quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + **predict_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Refer to the base method (``BaseChronosPipeline.predict_quantiles``). + """ + # shape (batch_size, prediction_length, len(training_quantile_levels)) + predictions = ( + self.predict(inputs, prediction_length=prediction_length, **predict_kwargs) + .detach() + .swapaxes(1, 2) + ) + + training_quantile_levels = self.quantiles + + if set(quantile_levels).issubset(set(training_quantile_levels)): + # no need to perform intra/extrapolation + quantiles = predictions[ + ..., [training_quantile_levels.index(q) for q in quantile_levels] + ] + else: + # we rely on torch for interpolating quantiles if quantiles that + # Chronos Bolt was trained on were not provided + if min(quantile_levels) < min(training_quantile_levels) or max( + quantile_levels + ) > max(training_quantile_levels): + logger.warning( + f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of " + f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). " + "Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt " + "was trained on. This may significantly affect the quality of the predictions." + ) + + # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels) + # made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9). + # While this holds for official Chronos-Bolt models, this may not be true in the future, and this + # function may have to be revised. 
+ augmented_predictions = torch.cat( + [predictions[..., [0]], predictions, predictions[..., [-1]]], + dim=-1, + ) + quantiles = torch.quantile( + augmented_predictions, + q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype), + dim=-1, + ).permute(1, 2, 0) + # NOTE: the median is returned as the mean here + mean = predictions[:, :, training_quantile_levels.index(0.5)] + return quantiles, mean + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + """ + Load the model, either from a local path S3 prefix or from the HuggingFace Hub. + Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``. + """ + + if str(pretrained_model_name_or_path).startswith("s3://"): + return BaseChronosPipeline.from_pretrained( + pretrained_model_name_or_path, *args, **kwargs + ) + + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, *args, **kwargs + ) + assert hasattr(config, "chronos_config"), "Not a Chronos config file" + + architecture = config.architectures[0] + class_ = globals().get(architecture) + + if class_ is None: + logger.warning( + f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting" + ) + class_ = ChronosBoltModelForForecasting + + model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + return cls(model=model) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/config.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/config.py new file mode 100644 index 0000000000000..226e1ccfdf825 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/config.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from dataclasses import dataclass +from typing import List, Literal + +from transformers.configuration_utils import PretrainedConfig + + +class Chronos2CoreConfig(PretrainedConfig): + """ + HF transformers-style pretrained model config for Chronos-2.0, based on T5Config. 
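# Toy version of the interpolation fallback implemented just above: levels
# the model was not trained on are read off with torch.quantile across the
# trained-quantile axis, after duplicating the edge columns so out-of-range
# levels clamp to the outermost trained quantiles.
import torch

preds = torch.arange(1.0, 10.0).view(1, 1, 9)         # (batch, horizon, 9 trained levels)
augmented = torch.cat([preds[..., [0]], preds, preds[..., [-1]]], dim=-1)
levels = torch.tensor([0.05, 0.45, 0.95])
interp = torch.quantile(augmented, q=levels, dim=-1).permute(1, 2, 0)
# interp.shape == (1, 1, 3); 0.05 and 0.95 clamp to the 0.1 / 0.9 columns,
# while 0.45 is interpolated between the 0.4 and 0.5 columns.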
+ + Arguments + ---------- + d_model + Size model's hidden states, by default 512 + d_kv + Size of the key, query, value projections per attention head, by default 64 + d_ff + Size of the intermediate feed forward layers, by default 2048 + num_layers + Number of hidden layers in the encoder, by default 6 + num_heads + Number of attention heads for each attention layer, by default 8 + dropout_rate + The ratio for all dropout layers, by default 0.1 + layer_norm_epsilon + The epsilon used by the layer normalization layers, by default 1e-6 + initializer_factor + A factor for initializing all weight matrices, by default 0.05 + feed_forward_proj + Type of feed forward layer to be used, by default "relu" + vocab_size + Size of vocabulary for special tokens, by default 2 + pad_token_id + Token ID for padding/missing value token, by default 0 + rope_theta + The base theta for rotary position embedding (RoPE), by default 10000.0 + attn_implementation + The attention implementation to use. Options: "eager" or "sdpa", by default None (uses "sdpa") + """ + + model_type = "t5" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } + + def __init__( + self, + d_model: int = 512, + d_kv: int = 64, + d_ff: int = 2048, + num_layers: int = 6, + num_heads: int = 8, + dropout_rate: float = 0.1, + layer_norm_epsilon: float = 1e-6, + initializer_factor: float = 0.05, + feed_forward_proj: str = "relu", + vocab_size: int = 2, + pad_token_id: int = 0, + rope_theta: float = 10000.0, + attn_implementation: Literal["eager", "sdpa"] | None = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.rope_theta = rope_theta + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + assert not self.is_gated_act, "gated activation is not supported" + + # Attention implementation - default to "sdpa" if not specified + attn_implementation = attn_implementation or "sdpa" + assert attn_implementation in [ + "eager", + "sdpa", + ], f"attn_implementation {attn_implementation} not supported" + + # unused + kwargs.pop("is_encoder_decoder", None) + kwargs.pop("eos_token_id", None) + + super().__init__( + pad_token_id=pad_token_id, + is_encoder_decoder=False, + attn_implementation=attn_implementation, + **kwargs, + ) + + +@dataclass +class Chronos2ForecastingConfig: + context_length: int + output_patch_size: int + input_patch_size: int + input_patch_stride: int + quantiles: List[float] + use_reg_token: bool = False + use_arcsinh: bool = False + max_output_patches: int = 1 + time_encoding_scale: int | None = None + + @classmethod + def editable_fields(cls) -> list[str]: + """ + Fields that maybe modified during the fine-tuning stage. 
+ """ + return ["context_length", "max_output_patches"] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/dataset.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/dataset.py new file mode 100644 index 0000000000000..cf3b5edbff34d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/dataset.py @@ -0,0 +1,756 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import math +from enum import Enum +from typing import TYPE_CHECKING, Iterator, Mapping, Sequence, TypeAlias, cast + +import numpy as np +import torch +from sklearn.preprocessing import OrdinalEncoder, TargetEncoder +from torch.utils.data import IterableDataset + +if TYPE_CHECKING: + import datasets + import fev + + +TensorOrArray: TypeAlias = torch.Tensor | np.ndarray + + +def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor: + """ + Left pads tensors in the list to the length of the longest tensor along the second axis, then concats + these equal length tensors along the first axis. + """ + max_len = max(tensor.shape[-1] for tensor in tensors) + padded = [] + for tensor in tensors: + n_variates, length = tensor.shape + if length < max_len: + padding = torch.full( + (n_variates, max_len - length), + fill_value=torch.nan, + device=tensor.device, + ) + tensor = torch.cat([padding, tensor], dim=-1) + padded.append(tensor) + + return torch.cat(padded, dim=0) + + +def validate_and_prepare_single_dict_task( + task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], + idx: int, + prediction_length: int, +) -> tuple[torch.Tensor, torch.Tensor, int, int, int]: + """Validates and prepares a single dictionary task for Chronos2Model. + + Parameters + ---------- + task + A dictionary representing a time series that contains: + - `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length). + Forecasts will be generated for items in `target`. + - `past_covariates` (optional): a dict of past-only covariates or past values of known future covariates. The keys of the dict + must be names of the covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `history_length` + of `target`. + - `future_covariates` (optional): a dict of future values of known future covariates. The keys of the dict must be names of the + covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `prediction_length`. All keys in + `future_covariates` must be a subset of the keys in `past_covariates`. 
+ idx + Index of this task in the list of tasks, used for error messages + prediction_length + Number of future time steps to predict, used to validate future covariates + + Returns + ------ + A tuple containing: + - task_context_tensor: Concatenated tensor of target and past covariates of shape (group_size, history_length), + the first `task_n_targets` items along the first axis contain the target variables and the remaining items contain past-only covariates + and past values of known future covariates. + - task_future_covariates_tensor: Tensor of future covariates of shape (group_size, prediction_length). The last `task_n_future_covariates` + items along the first axis contain future covariates. All the remaining elements corresponding to target and past-only covariates are NaNs. + - task_n_targets: Number of target variables + - task_n_covariates: Total number of covariates (sum of past-only and known future covariates) + - task_n_future_covariates: Number of known future covariates + """ + + allowed_keys = {"target", "past_covariates", "future_covariates"} + + # validate keys + keys = set(task.keys()) + if not keys.issubset(allowed_keys): + raise ValueError( + f"Found invalid keys in element at index {idx}. Allowed keys are {allowed_keys}, but found {keys}" + ) + if "target" not in keys: + raise ValueError( + f"Element at index {idx} does not contain the required key 'target'" + ) + + # validate target + task_target = task["target"] + if isinstance(task_target, np.ndarray): + task_target = torch.from_numpy(task_target) + assert isinstance(task_target, torch.Tensor) + if task_target.ndim > 2: + raise ValueError( + "When the input is a list of dicts, the `target` should either be 1-d with shape (history_length,) " + f" or 2-d with shape (n_variates, history_length). Found element at index {idx} with shape {tuple(task_target.shape)}." + ) + history_length = task_target.shape[-1] + task_target = task_target.view(-1, history_length) + + # validate past_covariates + cat_encoders: dict = {} + task_past_covariates = task.get("past_covariates", {}) + if not isinstance(task_past_covariates, dict): + raise ValueError( + f"Found invalid type for `past_covariates` in element at index {idx}. " + f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_past_covariates)}' + ) + + # gather keys and ensure known-future keys come last to match downstream assumptions + task_covariates_keys = sorted(task_past_covariates.keys()) + + task_future_covariates = task.get("future_covariates", {}) + if not isinstance(task_future_covariates, dict): + raise ValueError( + f"Found invalid type for `future_covariates` in element at index {idx}. 
" + f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}' + ) + task_future_covariates_keys = sorted(task_future_covariates.keys()) + if not set(task_future_covariates_keys).issubset(task_covariates_keys): + raise ValueError( + f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, " + f"but found {task_future_covariates_keys} in element at index {idx}" + ) + + # create ordered keys: past-only first, then known-future (so known-future are the last rows) + task_past_only_keys = [ + k for k in task_covariates_keys if k not in task_future_covariates_keys + ] # past_only_keys + task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys + + task_past_covariates_list: list[torch.Tensor] = [] + for key in task_ordered_covariate_keys: + tensor = task_past_covariates[key] + if isinstance(tensor, np.ndarray): + # apply encoding to categorical variates + if not np.issubdtype(tensor.dtype, np.number): + # target encoding, if the target is 1-d + if task_target.shape[0] == 1: + cat_encoder = TargetEncoder(target_type="continuous", smooth=1.0) + X = tensor.astype(str).reshape(-1, 1) + y = task_target.view(-1).numpy() + mask = np.isfinite(y) + X = X[mask] + y = y[mask] + cat_encoder.fit(X, y) + # ordinal encoding, if the target is > 1-d + else: + cat_encoder = OrdinalEncoder( + handle_unknown="use_encoded_value", unknown_value=np.nan + ) + cat_encoder.fit(tensor.astype(str).reshape(-1, 1)) + tensor = cat_encoder.transform( + tensor.astype(str).reshape(-1, 1) + ).reshape(tensor.shape) + cat_encoders[key] = cat_encoder + tensor = torch.from_numpy(tensor) + assert isinstance(tensor, torch.Tensor) + if tensor.ndim != 1 or len(tensor) != history_length: + raise ValueError( + f"Individual `past_covariates` must be 1-d with length equal to the length of `target` (= {history_length}), " + f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}" + ) + task_past_covariates_list.append(tensor) + task_past_covariates_tensor = ( + torch.stack(task_past_covariates_list, dim=0) + if task_past_covariates_list + else torch.zeros((0, history_length), device=task_target.device) + ) + + # validate future_covariates (build rows in the same task_ordered_covariate_keys order) + task_future_covariates_list: list[torch.Tensor] = [] + for key in task_ordered_covariate_keys: + # future values of past-only covariates are filled with NaNs + tensor = task_future_covariates.get( + key, torch.full((prediction_length,), fill_value=torch.nan) + ) + if isinstance(tensor, np.ndarray): + # apply encoding to categorical variates + if not np.issubdtype(tensor.dtype, np.number): + cat_encoder = cat_encoders[key] + tensor = cat_encoder.transform( + tensor.astype(str).reshape(-1, 1) + ).reshape(tensor.shape) + tensor = torch.from_numpy(tensor) + assert isinstance(tensor, torch.Tensor) + if tensor.ndim != 1 or len(tensor) != prediction_length: + raise ValueError( + f"Individual `future_covariates` must be 1-d with length equal to the {prediction_length=}, " + f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}" + ) + task_future_covariates_list.append(tensor) + task_future_covariates_tensor = ( + torch.stack(task_future_covariates_list, dim=0) + if task_future_covariates_list + else torch.zeros((0, prediction_length), device=task_target.device) + ) + # future values of target series are filled with NaNs + task_future_covariates_target_padding = torch.full( + (task_target.shape[0], 
prediction_length), + fill_value=torch.nan, + device=task_target.device, + ) + + task_context_tensor = torch.cat( + [task_target, task_past_covariates_tensor], dim=0 + ).to(dtype=torch.float32) + task_future_covariates_tensor = torch.cat( + [task_future_covariates_target_padding, task_future_covariates_tensor], dim=0 + ).to(dtype=torch.float32) + task_n_targets = task_target.shape[0] + task_n_covariates = task_past_covariates_tensor.shape[0] + # number of known-future covariates + task_n_future_covariates = len(task_future_covariates_keys) + + return ( + task_context_tensor, + task_future_covariates_tensor, + task_n_targets, + task_n_covariates, + task_n_future_covariates, + ) + + +def convert_list_of_tensors_input_to_list_of_dicts_input( + list_of_tensors: Sequence[TensorOrArray], +) -> list[dict[str, torch.Tensor]]: + """Convert a list of tensors input format to a list of dictionaries input format. + + + Parameters + ---------- + list_of_tensors + A sequence of tensors or numpy arrays, where each element represents a time series. + Each element should be either 1-d with shape (history_length,) or 2-d with shape + (n_variates, history_length). + + Returns + ------- + A list of dictionaries, where each dictionary represents a time series and contains: + - `target`: a 1-d or 2-d torch.Tensor of shape (history_length,) or (n_variates, history_length). + """ + + output: list[dict[str, torch.Tensor]] = [] + for idx, tensor in enumerate(list_of_tensors): + if isinstance(tensor, np.ndarray): + tensor = torch.from_numpy(tensor) + if tensor.ndim > 2: + raise ValueError( + "When the input is a list of torch tensors or numpy arrays, the elements should either be 1-d with shape (history_length,) " + f" or 2-d with shape (n_variates, history_length). Found element at index {idx} with shape {tuple(tensor.shape)}." + ) + length = tensor.shape[-1] + tensor = tensor.view(-1, length) + + output.append({"target": tensor}) + + return output + + +def convert_tensor_input_to_list_of_dicts_input( + tensor: TensorOrArray, +) -> list[dict[str, torch.Tensor]]: + """ + Convert a tensor input format to a list of dictionaries input format. + + Parameters + ---------- + tensor + A tensor or numpy array representing multiple time series. + Should be 3-d with shape (n_series, n_variates, history_length). + + Returns + ------- + A list of dictionaries, where each dictionary represents a time series and contains: + - `target`: a 2-d torch.Tensor of shape (n_variates, history_length). + """ + + if isinstance(tensor, np.ndarray): + tensor = torch.from_numpy(tensor) + if tensor.ndim != 3: + raise ValueError( + "When the input is a torch tensor or numpy array, it should be 3-d with shape (n_series, n_variates, history_length). " + f" Found shape: {tuple(tensor.shape)}." 
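# Toy illustration of the 3-d tensor convention handled here: shape
# (n_series, n_variates, history_length) becomes one dict per series.
import torch

stacked_input = torch.randn(3, 2, 50)                 # 3 series, 2 variates, 50 steps
as_dicts = [{"target": stacked_input[i]} for i in range(len(stacked_input))]
# len(as_dicts) == 3 and each "target" has shape (2, 50), matching the
# list-of-dicts format accepted elsewhere in this module.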
+ ) + + output: list[dict[str, torch.Tensor]] = [] + n_series = len(tensor) + for i in range(n_series): + output.append({"target": tensor[i]}) + + return output + + +def _cast_fev_features( + past_data: "datasets.Dataset", + future_data: "datasets.Dataset", + target_columns: list[str], + past_dynamic_columns: list[str], + known_dynamic_columns: list[str], +) -> tuple["datasets.Dataset", "datasets.Dataset"]: + import datasets + + dynamic_columns = [*past_dynamic_columns, *known_dynamic_columns] + cat_cols = [] + for col in dynamic_columns: + item = past_data[0][col] + if not np.issubdtype(item.dtype, np.number): + cat_cols.append(col) + + numeric_cols = target_columns + list(set(dynamic_columns) - set(cat_cols)) + past_feature_updates = { + col: datasets.Sequence(datasets.Value("float64")) for col in numeric_cols + } | {col: datasets.Sequence(datasets.Value("string")) for col in cat_cols} + past_data_features = past_data.features + past_data_features.update(past_feature_updates) + past_data = past_data.cast(past_data_features) + + future_cat_cols = [k for k in cat_cols if k in known_dynamic_columns] + future_numeric_cols = list(set(known_dynamic_columns) - set(future_cat_cols)) + future_feature_updates = { + col: datasets.Sequence(datasets.Value("float64")) for col in future_numeric_cols + } | {col: datasets.Sequence(datasets.Value("string")) for col in future_cat_cols} + future_data_features = future_data.features + future_data_features.update(future_feature_updates) + future_data = future_data.cast(future_data_features) + + return past_data, future_data + + +def convert_fev_window_to_list_of_dicts_input( + window: "fev.EvaluationWindow", as_univariate: bool +) -> tuple[ + list[dict[str, np.ndarray | dict[str, np.ndarray]]], list[str], list[str], list[str] +]: + import fev + + if as_univariate: + past_data, future_data = fev.convert_input_data( + window, adapter="datasets", as_univariate=True + ) + target_columns = ["target"] + past_dynamic_columns = [] + known_dynamic_columns = [] + else: + past_data, future_data = window.get_input_data() + target_columns = window.target_columns + past_dynamic_columns = window.past_dynamic_columns + known_dynamic_columns = window.known_dynamic_columns + + past_data, future_data = _cast_fev_features( + past_data=past_data, + future_data=future_data, + target_columns=target_columns, + past_dynamic_columns=past_dynamic_columns, + known_dynamic_columns=known_dynamic_columns, + ) + + num_series: int = len(past_data) + num_past_covariates: int = len(past_dynamic_columns) + num_future_covariates: int = len(known_dynamic_columns) + + # We use numpy format because torch does not support str covariates + target_data = past_data.select_columns(target_columns).with_format("numpy") + # past of past-only and known-future covariates + dynamic_columns = [*past_dynamic_columns, *known_dynamic_columns] + past_covariate_data = past_data.select_columns(dynamic_columns).with_format("numpy") + future_known_data = future_data.select_columns(known_dynamic_columns).with_format( + "numpy" + ) + + if num_past_covariates + num_future_covariates > 0: + assert len(past_covariate_data) == num_series + if num_future_covariates > 0: + assert len(future_known_data) == num_series + + inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = [] + for idx, target_row in enumerate(target_data): + target_row = cast(dict, target_row) + # this assumes that the targets have the same length for multivariate tasks + target_tensor_i = np.stack([target_row[col] for col in target_columns]) + 
entry: dict[str, np.ndarray | dict[str, np.ndarray]] = { + "target": target_tensor_i + } + + if len(dynamic_columns) > 0: + past_covariate_row = past_covariate_data[idx] + entry["past_covariates"] = { + col: past_covariate_row[col] for col in dynamic_columns + } + + if len(known_dynamic_columns) > 0: + future_known_row = future_known_data[idx] + entry["future_covariates"] = { + col: future_known_row[col] for col in known_dynamic_columns + } + + inputs.append(entry) + + return inputs, target_columns, past_dynamic_columns, known_dynamic_columns + + +class DatasetMode(str, Enum): + TRAIN = "train" + VALIDATION = "validation" + TEST = "test" + + +class Chronos2Dataset(IterableDataset): + """ + A dataset wrapper for Chronos-2 models. + + Arguments + ---------- + inputs + Time series data. Must be a list of dictionaries where each dictionary may have the following keys. + - `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length). + Forecasts will be generated for items in `target`. + - `past_covariates` (optional): a dict of past-only covariates or past values of known future covariates. The keys of the dict + must be names of the covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `history_length` + of `target`. + - `future_covariates` (optional): a dict of future values of known future covariates. The keys of the dict must be names of the + covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `prediction_length`. All keys in + `future_covariates` must be a subset of the keys in `past_covariates`. + Note: when the mode is set to TRAIN, the values inside `future_covariates` are not technically used for training the model; + however, this key is used to infer which covariates are known into the future. Therefore, if your task contains known future covariates, + make sure that this key exists in `inputs`. The values of individual future covariates may be set to `None` or an empty array. + context_length + The maximum context length used for training or inference + prediction_length + The prediction horizon + batch_size + The batch size for training the model. Note that the batch size here means the number of time series, including target(s) and + covariates, that are input into the model. If your data has multiple target and/or covariates, the effective number of time series + tasks in a batch will be lower than this value. + output_patch_size + The output patch size of the model. This is used to compute the number of patches needed to cover `prediction_length` + min_past + The minimum number of time steps the context must have during training. 
All time series shorter than `min_past + prediction_length` + are filtered out, by default 1 + mode + `DatasetMode` governing whether to generate training, validation or test samples, by default "train" + """ + + def __init__( + self, + inputs: Sequence[ + Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]] + ], + context_length: int, + prediction_length: int, + batch_size: int, + output_patch_size: int, + min_past: int = 1, + mode: str | DatasetMode = DatasetMode.TRAIN, + ) -> None: + super().__init__() + assert mode in { + DatasetMode.TRAIN, + DatasetMode.VALIDATION, + DatasetMode.TEST, + }, f"Invalid mode: {mode}" + + self.tasks = Chronos2Dataset._prepare_tasks( + inputs, prediction_length, min_past, mode + ) + self.context_length = context_length + self.prediction_length = prediction_length + self.batch_size = batch_size + self.num_output_patches = math.ceil(prediction_length / output_patch_size) + self.min_past = min_past + self.mode = mode + + @staticmethod + def _prepare_tasks( + inputs: Sequence[ + Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]] + ], + prediction_length: int, + min_past: int, + mode: str | DatasetMode, + ): + tasks = [] + for idx, raw_task in enumerate(inputs): + if mode != DatasetMode.TEST: + raw_future_covariates = raw_task.get("future_covariates", {}) + raw_future_covariates = cast( + dict[str, TensorOrArray | None], raw_future_covariates + ) + if raw_future_covariates: + fixed_future_covariates = {} + for key, value in raw_future_covariates.items(): + fixed_future_covariates[key] = ( + np.full(prediction_length, np.nan) + if value is None or len(value) == 0 + else value + ) + raw_task = { + **raw_task, + "future_covariates": fixed_future_covariates, + } + + raw_task = cast( + dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_task + ) + # convert to a format compatible with model's forward + task = validate_and_prepare_single_dict_task( + raw_task, idx, prediction_length + ) + + if ( + mode != DatasetMode.TEST + and task[0].shape[-1] < min_past + prediction_length + ): + # filter tasks based on min_past + prediction_length + continue + tasks.append(task) + + if len(tasks) == 0: + raise ValueError( + "The dataset is empty after filtering based on the length of the time series (length >= min_past + prediction_length). " + "Please provide longer time series or reduce `min_past` or `prediction_length`. 
" + ) + return tasks + + def _construct_slice( + self, task_idx: int + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]: + ( + task_past_tensor, # shape: (task_n_targets + task_n_covariates, history_length) + task_future_tensor, + task_n_targets, + task_n_covariates, + task_n_future_covariates, + ) = self.tasks[task_idx] + task_past_tensor, task_future_tensor = ( + task_past_tensor.clone(), + task_future_tensor.clone(), + ) + task_n_past_only_covariates = task_n_covariates - task_n_future_covariates + + full_length = task_past_tensor.shape[-1] + + if self.mode == DatasetMode.TRAIN: + # slice a random subsequence from the full series + slice_idx = np.random.randint( + self.min_past, full_length - self.prediction_length + 1 + ) + elif self.mode == DatasetMode.VALIDATION: + # slice the last window for validation + slice_idx = full_length - self.prediction_length + else: + # slice the full series for prediction + slice_idx = full_length + + if slice_idx >= self.context_length: + # slice series, if it is longer than context_length + task_context = task_past_tensor[ + :, slice_idx - self.context_length : slice_idx + ] + else: + task_context = task_past_tensor[:, :slice_idx] + + # In the TEST mode, we have no target available and the task_future_covariates can be directly used + # In the TRAIN and VALIDATION modes, the target and task_future_covariates need to be constructed from + # the task_context_tensor by slicing the appropriate indices which we do below + if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]: + # the first task_n_targets elements in task_context_tensor are the targets + task_future_target = task_past_tensor[ + :, slice_idx : slice_idx + self.prediction_length + ].clone() + # mask out all rows corresponding to covariates + task_future_target[task_n_targets:] = torch.nan + + if task_n_future_covariates > 0: + # the last task_n_future_covariates elements in task_context_tensor are the known covariates + task_future_covariates = task_past_tensor[ + -task_n_future_covariates:, + slice_idx : slice_idx + self.prediction_length, + ] + else: + # zero-length tensor for easy concatenation later + task_future_covariates = torch.zeros((0, self.prediction_length)) + + # the leading task_n_targets + task_n_past_only_covariates elements are masked because the target(s) + # and past-only covariates are not known into the future + task_future_covariates_padding = torch.full( + (task_n_targets + task_n_past_only_covariates, self.prediction_length), + fill_value=torch.nan, + ) + task_future_covariates = torch.cat( + [task_future_covariates_padding, task_future_covariates], dim=0 + ) + else: + task_future_target = None + task_future_covariates = task_future_tensor + + # task_context: (task_n_targets + task_n_covariates, min(context_length, history_length)) + # task_future_target: (task_n_targets + task_n_covariates, prediction_length), the future values of known future covariates + # are ignored during loss computation + # task_future_covariates: (task_n_targets + task_n_past_only_covariates + task_n_future_covariates, prediction_length), + # the entries corresponding to targets and past-only covariates are NaNs + + return task_context, task_future_target, task_future_covariates, task_n_targets + + def _build_batch( + self, task_indices: list[int] + ) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]: + """Build a batch from given task indices.""" + batch_context_tensor_list = [] + batch_future_target_tensor_list = [] + batch_future_covariates_tensor_list = 
[] + batch_group_ids_list = [] + target_idx_ranges: list[tuple[int, int]] = [] + + target_start_idx = 0 + for group_id, task_idx in enumerate(task_indices): + task_context, task_future_target, task_future_covariates, task_n_targets = ( + self._construct_slice(task_idx) + ) + + group_size = task_context.shape[0] + task_group_ids = torch.full((group_size,), fill_value=group_id) + batch_context_tensor_list.append(task_context) + batch_future_target_tensor_list.append(task_future_target) + batch_future_covariates_tensor_list.append(task_future_covariates) + batch_group_ids_list.append(task_group_ids) + target_idx_ranges.append( + (target_start_idx, target_start_idx + task_n_targets) + ) + target_start_idx += group_size + + return { + "context": left_pad_and_cat_2D(batch_context_tensor_list), + "future_target": ( + None + if self.mode == DatasetMode.TEST + else torch.cat( + cast(list[torch.Tensor], batch_future_target_tensor_list), dim=0 + ) + ), + "future_covariates": torch.cat(batch_future_covariates_tensor_list, dim=0), + "group_ids": torch.cat(batch_group_ids_list, dim=0), + "num_output_patches": self.num_output_patches, + "target_idx_ranges": target_idx_ranges, + } + + def _generate_train_batches(self): + while True: + current_batch_size = 0 + task_indices = [] + + while current_batch_size < self.batch_size: + task_idx = np.random.randint(len(self.tasks)) + task_indices.append(task_idx) + current_batch_size += self.tasks[task_idx][0].shape[0] + + yield self._build_batch(task_indices) + + def _generate_sequential_batches(self): + task_idx = 0 + while task_idx < len(self.tasks): + current_batch_size = 0 + task_indices = [] + + while task_idx < len(self.tasks) and current_batch_size < self.batch_size: + task_indices.append(task_idx) + current_batch_size += self.tasks[task_idx][0].shape[0] + task_idx += 1 + + yield self._build_batch(task_indices) + + def __iter__(self) -> Iterator: + """ + Generate batches of data for the Chronos-2 model. In training mode, this iterator is infinite. 
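+        Whole tasks are packed into each batch until the total number of component
+        series (targets plus covariates) reaches or exceeds ``batch_size``.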
+ + Yields + ------ + dict + A dictionary containing: + - context: torch.Tensor of shape (batch_size, context_length) containing input sequences + - future_target: torch.Tensor of shape (batch_size, prediction_length) containing future target sequences, None in TEST mode + - future_covariates: torch.Tensor of shape (batch_size, prediction_length) containing known future covariates + - group_ids: torch.Tensor of shape (batch_size,) containing the group ID for each sequence + - num_output_patches: int indicating number of patches the model should output to cover prediction_length + - target_idx_ranges: (only in TEST mode) list of tuples indicating the start & end indices of targets in context + """ + if self.mode == DatasetMode.TRAIN: + for batch in self._generate_train_batches(): + batch.pop("target_idx_ranges") + yield batch + elif self.mode == DatasetMode.VALIDATION: + for batch in self._generate_sequential_batches(): + batch.pop("target_idx_ranges") + yield batch + else: + yield from self._generate_sequential_batches() + + @classmethod + def convert_inputs( + cls, + inputs: ( + TensorOrArray + | Sequence[TensorOrArray] + | Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]] + ), + context_length: int, + prediction_length: int, + batch_size: int, + output_patch_size: int, + min_past: int = 1, + mode: str | DatasetMode = DatasetMode.TRAIN, + ) -> "Chronos2Dataset": + """Convert from different input formats to a Chronos2Dataset.""" + if isinstance(inputs, (torch.Tensor, np.ndarray)): + inputs = convert_tensor_input_to_list_of_dicts_input(inputs) + elif isinstance(inputs, list) and all( + [isinstance(x, (torch.Tensor, np.ndarray)) for x in inputs] + ): + inputs = cast(list[TensorOrArray], inputs) + inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs) + elif isinstance(inputs, list) and all([isinstance(x, dict) for x in inputs]): + pass + else: + raise ValueError("Unexpected inputs format") + + inputs = cast(list[dict[str, TensorOrArray | dict[str, TensorOrArray]]], inputs) + + return cls( + inputs, + context_length=context_length, + prediction_length=prediction_length, + batch_size=batch_size, + output_patch_size=output_patch_size, + min_past=min_past, + mode=mode, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/layers.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/layers.py new file mode 100644 index 0000000000000..818e3fc30cf1f --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/layers.py @@ -0,0 +1,481 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ModelOutput + +from iotdb.ainode.core.model.chronos2.config import Chronos2CoreConfig + + +class RoPE(nn.Module): + """Applies rotary position embeddings (RoPE) to input tensors. + + Implementation adapted from: + https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L95 + """ + + def __init__(self, dim: int, base: float = 10000): + super().__init__() + + self.dim = dim + self.base = base + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) + ) + self.inv_freq: torch.Tensor # type hint for type checker + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + @staticmethod + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + @staticmethod + def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
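+
+        Example:
+            A minimal sketch of the expected call pattern (shapes chosen only for
+            illustration, e.g. head_dim=4):
+
+            >>> rope = RoPE(dim=4)
+            >>> q = torch.randn(1, 2, 3, 4)  # (batch, n_heads, seq_len, head_dim)
+            >>> k = torch.randn(1, 2, 3, 4)
+            >>> position_ids = torch.arange(3).unsqueeze(0)  # (batch, seq_len)
+            >>> cos, sin = rope(q, position_ids)
+            >>> q_rot, k_rot = RoPE.apply_rotary_pos_emb(q, k, cos, sin)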
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (RoPE.rotate_half(q) * sin) + k_embed = (k * cos) + (RoPE.rotate_half(k) * sin) + return q_embed, k_embed + + +class Chronos2LayerNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +# This is how transformers keeps track of LayerNorm classes ¯\_(ツ)_/¯ +ALL_LAYERNORM_LAYERS.append(Chronos2LayerNorm) # type: ignore + + +class MLP(nn.Module): + def __init__(self, config: Chronos2CoreConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + def __init__(self, config: Chronos2CoreConfig): + super().__init__() + + assert not config.is_gated_act, "gated activations are unsupported" + self.mlp: nn.Module = MLP(config) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.mlp(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +@dataclass +class AttentionOutput(ModelOutput): + hidden_states: torch.Tensor | None = None + attn_weights: torch.Tensor | None = None + + +class MHA(nn.Module): + """Multi-head Attention Layer""" + + def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True): + super().__init__() + self.d_model: int = config.d_model + self.kv_proj_dim: int = config.d_kv + self.n_heads: int = config.num_heads + self.dropout: float = config.dropout_rate + self.inner_dim: int = self.n_heads * self.kv_proj_dim + self.config = config + + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + self.use_rope = use_rope + if use_rope: + self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta) + + def _eager_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Eager attention implementation using manual matmul. 
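+        The scores are intentionally left unscaled (no 1/sqrt(d_kv) factor), matching
+        the original Chronos-2 weights; see the Mesh-TensorFlow-style initialization
+        in `Chronos2Model._init_weights`.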
+ + Args: + query_states: [batch, n_heads, seq_len, kv_proj_dim] + key_states: [batch, n_heads, seq_len, kv_proj_dim] + value_states: [batch, n_heads, seq_len, kv_proj_dim] + mask: [batch, n_heads, q_len, kv_len] + + Returns: + attn_output: [batch, n_heads, seq_len, kv_proj_dim] + attn_weights: [batch, n_heads, q_len, kv_len] + """ + # Compute attention weights (no scaling - this is the original Chronos-2 implementation) + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # "bnqd,bnkd->bnqk" + scores += mask + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + return attn_output, attn_weights + + def _sdpa_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + mask: torch.Tensor, + ) -> tuple[torch.Tensor, None]: + """SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention. + + Args: + query_states: [batch, n_heads, seq_len, kv_proj_dim] + key_states: [batch, n_heads, seq_len, kv_proj_dim] + value_states: [batch, n_heads, seq_len, kv_proj_dim] + mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid) + + Returns: + attn_output: [batch, n_heads, seq_len, kv_proj_dim] + attn_weights: None (SDPA doesn't return weights) + """ + attn_output = nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + scale=1.0, # Match eager implementation (no scaling) + ) + + return attn_output, None + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + encoder_states: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> AttentionOutput: + """Multi-head attention forward pass. + + Args: + hidden_states : Input tensor of shape [batch_size, seq_len, d_model] + mask : Attention mask tensor of shape [batch_size, num_heads, q_len, kv_len] + encoder_states : Encoder states for cross-attention. Defaults to None. + position_ids : Position IDs for RoPE. Defaults to None. + output_attentions : Whether to return attention weights. Defaults to False. 
+ + Returns: + AttentionOutput: Contains: + - hidden_states : Output tensor of shape [batch_size, seq_len, d_model] + - attn_weights : Attention weights if output_attentions=True + """ + if self.use_rope: + assert ( + position_ids is not None + ), "position_ids must be provided when self.use_rope=True" + + # Force eager attention if output_attentions is True (only eager returns weights) + attn_implementation = self.config._attn_implementation + if output_attentions: + attn_implementation = "eager" + + seq_length = hidden_states.shape[1] + + def shape(states: torch.Tensor) -> torch.Tensor: + """(batch, seq_len, inner_dim) -> (batch, n_heads, seq_len, kv_proj_dim)""" + return rearrange( + states, + "b s (h d) -> b h s d", + h=self.n_heads, + s=seq_length, + d=self.kv_proj_dim, + ) + + def unshape(states: torch.Tensor) -> torch.Tensor: + """(batch, n_heads, seq_len, kv_proj_dim) -> (batch, seq_len, inner_dim)""" + return rearrange( + states, + "b h s d -> b s (h d)", + h=self.n_heads, + s=seq_length, + d=self.kv_proj_dim, + ) + + # Construct query states + query_states = shape(self.q(hidden_states)) + is_cross_attention = encoder_states is not None + + # Construct key/value states + if is_cross_attention: + key_states = shape(self.k(encoder_states)) + value_states = shape(self.v(encoder_states)) + else: + key_states = shape(self.k(hidden_states)) + value_states = shape(self.v(hidden_states)) + if self.use_rope: + cos, sin = self.rope_embed(value_states, position_ids) + query_states, key_states = RoPE.apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if attn_implementation == "sdpa": + attn_output, attn_weights = self._sdpa_attention( + query_states, key_states, value_states, mask + ) + else: # eager + attn_output, attn_weights = self._eager_attention( + query_states, key_states, value_states, mask + ) + + # Project attention output + attn_output = unshape(attn_output) + attn_output = self.o(attn_output) + + return AttentionOutput( + hidden_states=attn_output, + attn_weights=attn_weights if output_attentions else None, + ) + + +class TimeSelfAttention(nn.Module): + def __init__(self, config: Chronos2CoreConfig): + super().__init__() + self.self_attention = MHA(config, use_rope=True) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + output_attentions: bool = False, + ) -> AttentionOutput: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output: AttentionOutput = self.self_attention( + normed_hidden_states, + position_ids=position_ids, + mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + + return AttentionOutput( + hidden_states=hidden_states, attn_weights=attention_output.attn_weights + ) + + +class TimeCrossAttention(nn.Module): + def __init__(self, config: Chronos2CoreConfig): + super().__init__() + self.cross_attention = MHA(config, use_rope=False) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + encoder_states: torch.Tensor, + output_attentions: bool = False, + ) -> AttentionOutput: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output: AttentionOutput = 
self.cross_attention( + normed_hidden_states, + mask=attention_mask, + encoder_states=encoder_states, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + + return AttentionOutput( + hidden_states=hidden_states, attn_weights=attention_output.attn_weights + ) + + +class GroupSelfAttention(nn.Module): + """Self-attention applied along the batch axis masked by the group attention mask""" + + def __init__(self, config: Chronos2CoreConfig): + super().__init__() + # we don't use RoPE here because there's no natural ordering along the batch axis + self.self_attention = MHA(config, use_rope=False) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + ) -> AttentionOutput: + # flip time and batch axes because attention operates along dim=-2 + hidden_states = rearrange(hidden_states, "batch time d -> time batch d") + normed_hidden_states = self.layer_norm(hidden_states) + attention_output: AttentionOutput = self.self_attention( + normed_hidden_states, + mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + # flip time and batch axes back to their original position + hidden_states = rearrange(hidden_states, "time batch d -> batch time d") + + return AttentionOutput( + hidden_states=hidden_states, attn_weights=attention_output.attn_weights + ) + + +class ResidualBlock(nn.Module): + """A generic residual block which can be used for input and output embedding layers""" + + def __init__( + self, + in_dim: int, + h_dim: int, + out_dim: int, + act_fn_name: str, + dropout_p: float = 0.0, + use_layer_norm: bool = False, + ) -> None: + super().__init__() + + self.dropout = nn.Dropout(dropout_p) + self.hidden_layer = nn.Linear(in_dim, h_dim) + self.act = ACT2FN[act_fn_name] + self.output_layer = nn.Linear(h_dim, out_dim) + self.residual_layer = nn.Linear(in_dim, out_dim) + + self.use_layer_norm = use_layer_norm + if use_layer_norm: + self.layer_norm = Chronos2LayerNorm(out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hid = self.act(self.hidden_layer(x)) + out = self.dropout(self.output_layer(hid)) + res = self.residual_layer(x) + + out = out + res + + if self.use_layer_norm: + return self.layer_norm(out) + return out diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/model.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/model.py new file mode 100644 index 0000000000000..ae33015c13666 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/model.py @@ -0,0 +1,909 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import copy +from dataclasses import dataclass +from typing import cast + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +from iotdb.ainode.core.model.chronos2.chronos_bolt import InstanceNorm, Patch +from iotdb.ainode.core.model.chronos2.config import ( + Chronos2CoreConfig, + Chronos2ForecastingConfig, +) +from iotdb.ainode.core.model.chronos2.layers import ( + MHA, + MLP, + AttentionOutput, + Chronos2LayerNorm, + FeedForward, + GroupSelfAttention, + ResidualBlock, + TimeSelfAttention, +) + + +@dataclass +class Chronos2EncoderBlockOutput(ModelOutput): + hidden_states: torch.Tensor | None = None + time_self_attn_weights: torch.Tensor | None = None + group_self_attn_weights: torch.Tensor | None = None + + +class Chronos2EncoderBlock(nn.Module): + def __init__(self, config: Chronos2CoreConfig): + super().__init__() + assert not config.is_decoder + + self.layer = nn.ModuleList() + self.layer.append(TimeSelfAttention(config)) + self.layer.append(GroupSelfAttention(config)) + self.layer.append(FeedForward(config)) + + def forward( + self, + hidden_states: torch.Tensor, + *, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + group_time_mask: torch.Tensor, + output_attentions: bool = False, + ) -> Chronos2EncoderBlockOutput: + # apply time attention + time_self_attn_outputs: AttentionOutput = self.layer[0]( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = time_self_attn_outputs[0] + + # apply group attention + group_self_attn_outputs: AttentionOutput = self.layer[1]( + hidden_states, + attention_mask=group_time_mask, + output_attentions=output_attentions, + ) + hidden_states = group_self_attn_outputs[0] + + # apply feed forward layer + hidden_states = self.layer[2](hidden_states) + + return Chronos2EncoderBlockOutput( + hidden_states=hidden_states, + time_self_attn_weights=time_self_attn_outputs.attn_weights, + group_self_attn_weights=group_self_attn_outputs.attn_weights, + ) + + +@dataclass +class Chronos2EncoderOutput(ModelOutput): + last_hidden_state: torch.Tensor | None = None + all_time_self_attn_weights: tuple[torch.Tensor, ...] | None = None + all_group_self_attn_weights: tuple[torch.Tensor, ...] 
| None = None + + +class Chronos2Encoder(nn.Module): + def __init__(self, config: Chronos2CoreConfig): + super().__init__() + assert not config.is_decoder + + self.block = nn.ModuleList( + [Chronos2EncoderBlock(config) for i in range(config.num_layers)] + ) + self.final_layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + @staticmethod + def _expand_and_invert_time_attention_mask( + attention_mask: torch.Tensor, floating_type: torch.dtype + ) -> torch.Tensor: + assert ( + attention_mask.ndim == 2 + ), "attention_mask must have shape (batch, seq_len)" + + # Add new dims for attention heads and q_len + attention_mask = attention_mask[:, None, None, :] + + # Invert binary mask to float mask which can be added to attention scores + attention_mask = attention_mask.to(dtype=floating_type) + attention_mask = (1.0 - attention_mask) * torch.finfo(floating_type).min + return attention_mask + + @staticmethod + def _construct_and_invert_group_time_mask( + group_ids: torch.Tensor, + attention_mask: torch.Tensor, + floating_type: torch.dtype, + ) -> torch.Tensor: + # construct group_mask (batch, batch) from group ids + # a cell is True if both row and col had the same group id + group_mask = group_ids[:, None] == group_ids[None, :] + # outer product of group_mask and attention_mask (time_mask) + # group_time_mask combines group and time masks to ensure that attention only uses + # tokens from the same group which are also not masked in time + group_time_mask = torch.einsum("qb, bt -> qbt", group_mask, attention_mask) + + if torch.is_floating_point(group_time_mask): + # this ensures that mixed precision training does not overflow + floating_type = group_time_mask.dtype + + # reshape mask to shape of attention scores + group_time_mask = rearrange(group_time_mask, "q b t -> t 1 q b") + group_time_mask = (1.0 - group_time_mask) * torch.finfo(floating_type).min + + return group_time_mask + + def forward( + self, + inputs_embeds: torch.Tensor, + *, + group_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> Chronos2EncoderOutput: + batch_size, seq_length = inputs_embeds.size()[:-1] + + if position_ids is None: + position_ids = torch.arange( + 0, seq_length, dtype=torch.long, device=inputs_embeds.device + ).unsqueeze(0) + + if attention_mask is None: + attention_mask = torch.ones( + batch_size, + seq_length, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) + + # make the time attention mask broadcastable to attention scores (batch, n_heads, q_len, kv_len) and invert + extended_attention_mask = self._expand_and_invert_time_attention_mask( + attention_mask, inputs_embeds.dtype + ) + + # construct group time mask + group_time_mask = self._construct_and_invert_group_time_mask( + group_ids, attention_mask, inputs_embeds.dtype + ) + + all_time_self_attentions: tuple[torch.Tensor, ...] = () + all_group_self_attentions: tuple[torch.Tensor, ...] 
= () + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module) in enumerate(self.block): + layer_outputs: Chronos2EncoderBlockOutput = layer_module( + hidden_states, + position_ids=position_ids, + attention_mask=extended_attention_mask, + group_time_mask=group_time_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + assert layer_outputs.time_self_attn_weights is not None + assert layer_outputs.group_self_attn_weights is not None + + all_time_self_attentions = ( + *all_time_self_attentions, + layer_outputs.time_self_attn_weights, + ) + all_group_self_attentions = ( + *all_group_self_attentions, + layer_outputs.group_self_attn_weights, + ) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + return Chronos2EncoderOutput( + last_hidden_state=hidden_states, + all_time_self_attn_weights=all_time_self_attentions, + all_group_self_attn_weights=all_group_self_attentions, + ) + + +@dataclass +class Chronos2Output(ModelOutput): + loss: torch.Tensor | None = None + quantile_preds: torch.Tensor | None = None + enc_time_self_attn_weights: tuple[torch.Tensor, ...] | None = None + enc_group_self_attn_weights: tuple[torch.Tensor, ...] | None = None + + +class Chronos2Model(PreTrainedModel): + config_class = Chronos2CoreConfig # type: ignore[assignment] + _supports_long_horizon: bool = True + _supports_future_covariates: bool = True + _supports_sdpa: bool = True + + def __init__(self, config: Chronos2CoreConfig): + assert hasattr(config, "chronos_config"), "Not a valid Chronos config" + + super().__init__(config) + self.config: Chronos2CoreConfig + self.model_dim = config.d_model + + config.chronos_config["time_encoding_scale"] = config.chronos_config.get( + "time_encoding_scale", config.chronos_config["context_length"] + ) + self.chronos_config = Chronos2ForecastingConfig(**config.chronos_config) + + assert ( + self.chronos_config.input_patch_size + == self.chronos_config.output_patch_size + ), ( + "input_patch_size and output_patch_size sizes must be equal, " + f"but found {self.chronos_config.input_patch_size} and {self.chronos_config.output_patch_size}" + ) + + # Only [PAD] token (and [REG] token) + if self.chronos_config.use_reg_token: + config.reg_token_id = 1 + + config.vocab_size = 2 if self.chronos_config.use_reg_token else 1 + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + # Input patch embedding layer + self.input_patch_embedding = ResidualBlock( + # x3 for [time_embedding, patch, patch_mask] + in_dim=self.chronos_config.input_patch_size * 3, + h_dim=config.d_ff, + out_dim=config.d_model, + act_fn_name=config.dense_act_fn, + dropout_p=config.dropout_rate, + ) + + # patching layer + self.patch = Patch( + patch_size=self.chronos_config.input_patch_size, + patch_stride=self.chronos_config.input_patch_stride, + ) + + # instance normalization, also referred to as "scaling" in Chronos and GluonTS + self.instance_norm = InstanceNorm(use_arcsinh=self.chronos_config.use_arcsinh) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + self.encoder = Chronos2Encoder(encoder_config) + + self.num_quantiles = len(self.chronos_config.quantiles) + quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype) + self.quantiles: torch.Tensor + self.register_buffer("quantiles", quantiles, persistent=False) + + self.output_patch_embedding = ResidualBlock( + in_dim=config.d_model, + h_dim=config.d_ff, + out_dim=self.num_quantiles * 
self.chronos_config.output_patch_size, + act_fn_name=config.dense_act_fn, + dropout_p=config.dropout_rate, + ) + + # Initialize weights and apply final processing + self.post_init() + + def _init_weights(self, module): + super()._init_weights(module) + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, Chronos2LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, MLP): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, MHA): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + kv_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5) + ) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5) + ) + elif isinstance(module, (Chronos2Model)): + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, ResidualBlock): + module.hidden_layer.weight.data.normal_( + mean=0.0, + std=factor * (module.hidden_layer.weight.size(-1) ** -0.5), + ) + if ( + hasattr(module.hidden_layer, "bias") + and module.hidden_layer.bias is not None + ): + module.hidden_layer.bias.data.zero_() + + module.residual_layer.weight.data.normal_( + mean=0.0, + std=factor * (module.residual_layer.weight.size(-1) ** -0.5), + ) + if ( + hasattr(module.residual_layer, "bias") + and module.residual_layer.bias is not None + ): + module.residual_layer.bias.data.zero_() + + module.output_layer.weight.data.normal_( + mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5) + ) + if ( + hasattr(module.output_layer, "bias") + and module.output_layer.bias is not None + ): + module.output_layer.bias.data.zero_() + + def _validate_input( + self, + context: torch.Tensor, + context_mask: torch.Tensor | None, + group_ids: torch.Tensor | None, + future_covariates: torch.Tensor | None, + future_covariates_mask: torch.Tensor | None, + num_output_patches: int, + future_target: torch.Tensor | None, + future_target_mask: torch.Tensor | None, + ): + output_patch_size = self.chronos_config.output_patch_size + if context.ndim != 2: + raise ValueError( + f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}" + ) + if context_mask is not None and context_mask.shape != context.shape: + raise ValueError( + f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}" + ) + if future_covariates is not None: + if ( + future_covariates.shape[0] != context.shape[0] + or future_covariates.ndim != 2 + ): + raise ValueError( + 
f"future_covariates must have shape (batch_size={context.shape[0]}, future_length), found: {tuple(future_covariates.shape)}" + ) + if future_covariates.shape[-1] > num_output_patches * output_patch_size: + raise ValueError( + f"{num_output_patches=} must be large enough to accommodate the length of future_covariates, " + f"found: {future_covariates.shape[-1]} > {num_output_patches} * {output_patch_size}" + ) + if ( + future_target is not None + and future_target.shape != future_covariates.shape + ): + raise ValueError( + f"future_target must have the same shape as future_covariates, found: {tuple(future_target.shape)} and {tuple(future_covariates.shape)}" + ) + if future_covariates_mask is not None: + if future_covariates is None: + raise ValueError( + "future_covariates must be provided if future_covariates_mask is provided" + ) + if future_covariates_mask.shape != future_covariates.shape: + raise ValueError( + f"future_covariates_mask must have the same shape as future_covariates, " + f"found: {tuple(future_covariates_mask.shape)} and {tuple(future_covariates.shape)}" + ) + if group_ids is not None and group_ids.shape != (context.shape[0],): + raise ValueError( + f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}" + ) + if future_target is not None: + if future_target.shape[0] != context.shape[0] or future_target.ndim != 2: + raise ValueError( + f"future_target must have shape (batch_size={context.shape[0]}, future_length), found: {tuple(future_target.shape)}" + ) + if future_target.shape[-1] > output_patch_size * num_output_patches: + raise ValueError( + f"{num_output_patches=} must be large enough to accommodate the length of future_target, " + f"found: {future_target.shape[-1]} > {num_output_patches} * {output_patch_size}" + ) + if future_target_mask is not None: + if future_target is None: + raise ValueError( + "future_target must be provided if future_target_mask is provided" + ) + if future_target_mask.shape != future_target.shape: + raise ValueError( + f"future_target_mask must have the same shape as future_target, found: {tuple(future_target_mask.shape)} and {tuple(future_target.shape)}" + ) + + def _prepare_patched_context( + self, context: torch.Tensor, context_mask: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + context_mask = ( + context_mask.to(context.dtype) + if context_mask is not None + else torch.isnan(context).logical_not().to(context.dtype) + ) + + batch_size, context_length = context.shape + # truncate context if it's longer than model's context length + if context_length > self.chronos_config.context_length: + context = context[..., -self.chronos_config.context_length :] + context_mask = context_mask[..., -self.chronos_config.context_length :] + + # scaling + context, loc_scale = self.instance_norm(context) + + # scaling is done in 32-bit precision, then the context is moved to model's dtype + context = context.to(self.dtype) + context_mask = context_mask.to(self.dtype) + + # patching + patched_context = self.patch(context) + patched_mask = torch.nan_to_num(self.patch(context_mask), nan=0.0) + patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0) + + # attention_mask = 1 if at least one item in the patch is observed + attention_mask = patched_mask.sum(dim=-1) > 0 # (batch_size, num_patches) + num_context_patches = attention_mask.shape[-1] + + # context time encoding: every observation is assigned a sequential time index, + # scaled by model's context length = [-C, 
-(C-1), ..., -1] / context_length + final_context_length = ( + num_context_patches * self.chronos_config.input_patch_size + ) + context_time_enc = torch.arange( + start=-final_context_length, end=0, device=self.device, dtype=torch.float32 + ) + context_time_enc = ( + repeat( + context_time_enc, + "(n p) -> b n p", + b=batch_size, + n=num_context_patches, + p=self.chronos_config.input_patch_size, + ) + .div(cast(int, self.chronos_config.time_encoding_scale)) + .to(self.dtype) + ) + + # concat time encoding, context and mask along the last (feature) dim + patched_context = torch.cat( + [context_time_enc, patched_context, patched_mask], dim=-1 + ) + + return patched_context, attention_mask, loc_scale + + def _prepare_patched_future( + self, + future_covariates: torch.Tensor | None, + future_covariates_mask: torch.Tensor | None, + loc_scale: tuple[torch.Tensor, torch.Tensor], + num_output_patches: int, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + output_patch_size = self.chronos_config.output_patch_size + if future_covariates is not None: + future_covariates, _ = self.instance_norm(future_covariates, loc_scale) + future_covariates = cast(torch.Tensor, future_covariates) + future_covariates = future_covariates.to(self.dtype) + + if future_covariates_mask is None: + future_covariates_mask = ( + torch.isnan(future_covariates) + .logical_not() + .to(future_covariates.dtype) + ) + + future_covariates = torch.where( + future_covariates_mask > 0.0, future_covariates, 0.0 + ) + + if torch.isnan(future_covariates).any(): + raise ValueError( + "future_covariates contains NaN values at indices not masked by future_covariates_mask. " + "Input the correct future_covariates_mask or omit it to automatically infer the mask based on NaN values." + ) + + # add padding if the length of future_covariates is not an integer multiple of output_patch_size + if num_output_patches * output_patch_size > future_covariates.shape[-1]: + padding_shape = ( + *future_covariates.shape[:-1], + num_output_patches * output_patch_size + - future_covariates.shape[-1], + ) + future_covariates = torch.cat( + [ + future_covariates, + torch.zeros(padding_shape).to(future_covariates), + ], + dim=-1, + ) + future_covariates_mask = torch.cat( + [ + future_covariates_mask, + torch.zeros(padding_shape).to(future_covariates_mask), + ], + dim=-1, + ) + + patched_future_covariates = rearrange( + future_covariates, + "b (n p) -> b n p", + n=num_output_patches, + p=output_patch_size, + ) + patched_future_covariates_mask = rearrange( + future_covariates_mask, + "b (n p) -> b n p", + n=num_output_patches, + p=output_patch_size, + ) + else: + patched_future_covariates = torch.zeros( + batch_size, + num_output_patches, + output_patch_size, + device=self.device, + dtype=self.dtype, + ) + patched_future_covariates_mask = torch.zeros( + batch_size, + num_output_patches, + output_patch_size, + device=self.device, + dtype=self.dtype, + ) + + # future time encoding: every future timestep is assigned a sequential time index, + # scaled by model's context length = [0, 1, ..., h-1] / context_length + final_future_length = num_output_patches * output_patch_size + future_time_enc = torch.arange( + start=0, end=final_future_length, device=self.device, dtype=torch.float32 + ) + future_time_enc = ( + repeat( + future_time_enc, + "(n p) -> b n p", + b=batch_size, + n=num_output_patches, + p=output_patch_size, + ) + .div(cast(int, self.chronos_config.time_encoding_scale)) + .to(self.dtype) + ) + + patched_future = torch.cat( + [ + future_time_enc, 
+ patched_future_covariates, + patched_future_covariates_mask, + ], + dim=-1, + ) + + return patched_future, patched_future_covariates_mask + + def _compute_loss( + self, + quantile_preds: torch.Tensor, + future_target: torch.Tensor, + future_target_mask: torch.Tensor | None, + patched_future_covariates_mask: torch.Tensor, + loc_scale: tuple[torch.Tensor, torch.Tensor], + num_output_patches: int, + ) -> torch.Tensor: + batch_size = future_target.shape[0] + output_patch_size = self.chronos_config.output_patch_size + assert ( + quantile_preds.shape[0] == batch_size + and quantile_preds.shape[-1] >= future_target.shape[-1] + ) + + # normalize target and mask + future_target, _ = self.instance_norm(future_target, loc_scale) + future_target = future_target.unsqueeze(1) + future_target = future_target.to(self.device) + future_target_mask = ( + future_target_mask.unsqueeze(1).to(self.device) + if future_target_mask is not None + else ~torch.isnan(future_target) + ) + future_target = torch.where(future_target_mask > 0.0, future_target, 0.0) + + # pad target and target_mask if they are shorter than model's prediction + if quantile_preds.shape[-1] > future_target.shape[-1]: + padding_shape = ( + *future_target.shape[:-1], + quantile_preds.shape[-1] - future_target.shape[-1], + ) + future_target = torch.cat( + [future_target, torch.zeros(padding_shape).to(future_target)], dim=-1 + ) + future_target_mask = torch.cat( + [future_target_mask, torch.zeros(padding_shape).to(future_target_mask)], + dim=-1, + ) + + quantiles = rearrange(self.quantiles, "num_quantiles -> 1 num_quantiles 1") + quantile_loss = 2 * torch.abs( + (future_target - quantile_preds) + * ((future_target <= quantile_preds).float() - quantiles) + ) + inv_future_covariate_mask = 1 - rearrange( + patched_future_covariates_mask, + "b n p -> b 1 (n p)", + b=batch_size, + n=num_output_patches, + p=output_patch_size, + ) + # the first components masks any missing targets and the second component masks known future values + loss_mask = future_target_mask.float() * inv_future_covariate_mask + loss = quantile_loss * loss_mask + # mean over prediction horizon, sum over quantile levels and mean over batch + loss = loss.mean(dim=-1).sum(dim=-1).mean() + + return loss + + def encode( + self, + context: torch.Tensor, + context_mask: torch.Tensor | None = None, + group_ids: torch.Tensor | None = None, + future_covariates: torch.Tensor | None = None, + future_covariates_mask: torch.Tensor | None = None, + num_output_patches: int = 1, + future_target: torch.Tensor | None = None, + future_target_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ): + self._validate_input( + context=context, + context_mask=context_mask, + future_covariates=future_covariates, + future_covariates_mask=future_covariates_mask, + group_ids=group_ids, + num_output_patches=num_output_patches, + future_target=future_target, + future_target_mask=future_target_mask, + ) + + batch_size = context.shape[0] + patched_context, attention_mask, loc_scale = self._prepare_patched_context( + context=context, context_mask=context_mask + ) + num_context_patches = attention_mask.shape[-1] + + # get input embeddings of shape (batch, num_context_patches, d_model) + input_embeds: torch.Tensor = self.input_patch_embedding(patched_context) + # append [REG] special token embedding, if needed + if self.chronos_config.use_reg_token: + reg_input_ids = torch.full( + (batch_size, 1), self.config.reg_token_id, device=input_embeds.device + ) + reg_embeds = self.shared(reg_input_ids) + 
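+            # reg_embeds has shape (batch_size, 1, d_model); it is concatenated after
+            # the context patch embeddings (and before the future patch embeddings)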
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2) + attention_mask = torch.cat( + [ + attention_mask.to(self.dtype), + torch.ones_like(reg_input_ids).to(self.dtype), + ], + dim=-1, + ) + + patched_future, patched_future_covariates_mask = self._prepare_patched_future( + future_covariates=future_covariates, + future_covariates_mask=future_covariates_mask, + loc_scale=loc_scale, + num_output_patches=num_output_patches, + batch_size=batch_size, + ) + future_attention_mask = torch.ones( + batch_size, num_output_patches, dtype=self.dtype, device=self.device + ) + + # get future embeddings of shape (batch, num_output_patches, d_model) + future_embeds: torch.Tensor = self.input_patch_embedding(patched_future) + + # concatenate context and future embeddings and masks + input_embeds = torch.cat([input_embeds, future_embeds], dim=-2) + attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1) + + if group_ids is None: + # by default, each time series is treated independently, i.e., no mixing across the batch + group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device) + + encoder_outputs: Chronos2EncoderOutput = self.encoder( + attention_mask=attention_mask, + inputs_embeds=input_embeds, + group_ids=group_ids, + output_attentions=output_attentions, + ) + return ( + encoder_outputs, + loc_scale, + patched_future_covariates_mask, + num_context_patches, + ) + + def forward( + self, + context: torch.Tensor, + context_mask: torch.Tensor | None = None, + group_ids: torch.Tensor | None = None, + future_covariates: torch.Tensor | None = None, + future_covariates_mask: torch.Tensor | None = None, + num_output_patches: int = 1, + future_target: torch.Tensor | None = None, + future_target_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> Chronos2Output: + """Forward pass of the Chronos2 model. + + Parameters + ---------- + context + Input tensor of shape (batch_size, context_length) containing the historical values + context_mask + Binary mask tensor of same shape as context indicating which values are valid (1) vs missing (0) + If missing, the context_mask will be automatically constructed based on the NaN values in context. + group_ids : torch.Tensor | None, optional + Group IDs of shape (batch_size,) indicating which times series in the batch form a group. + A group indicates a task, for example, for a batch of size 6: + - if groups_ids = [0, 1, 2, 3, 4, 5], each time series is treated independently. + - if groups_ids = [0, 0, 1, 1, 1, 2], information is mixed across the first two time series (id=0), + the next three time series (id=1) and the last time series is treated separately. Information is + NOT shared among time series from different groups. + The ordering and specific values of group_ids are not important, all time series with the same group + ID form a group. + future_covariates + Tensor of shape (batch_size, future_length) containing future covariates. Note that the size of + tensor along the first axis is equal to the batch_size. This means that future values (which may be NaNs) + must be provided for each time series in the batch. For any time series that need to be forecasted, the + future_covariates can be set to NaNs, if ``future_covariates_mask`` is omitted or to an arbitrary dummy + value when ``future_covariates_mask`` is provided. ``future_covariates`` can be used with ``group_ids`` + to construct heterogenous forecasting tasks in a single batch. 
For example: + - future_covariates = [[nan, ...], [nan, ...], [v1, ...], [v2, ...], [nan, ...], [nan, ...]] + - groups_ids = [0, 0, 1, 1, 1, 2] + - future_covariates_mask = None + contains 3 types of forecasting tasks: + - [0, 0]: The first task, both future_covariates are missing, which implies that the two time series need to + be forecasted jointly, i.e., multivariate forecasting. + - [1, 1, 1]: In the next task, the first two future_covariates are available and the last one is missing + ([v1, ...], [v2, ...], [nan, ...]), where [v1, ...] and [v1, ...] denote an arbitrary sequence of values. + This indicates that the first two time series are known covariates and the third one needs to be forecasted + by the model. + - [2]: The last task has a single time series in the group which needs to be forecasted independently. + There is no theoretical limit on the number of time series in a group, i.e., the number of targets and known + covariates in a task. The above setup subsumes tasks with past-only covariates as the model's prediction for + those time series can simply be ignored downstream. + future_covariates_mask + Binary mask tensor of same shape as future_covariates indicating which future values are known + If omitted, future_covariates_mask is automatically constructed based on future_covariates with + all non-NaN values treated as known future values. + num_output_patches + Number of output patches to generate predictions for, by default 1 + When ``future_covariates`` and/or ``future_target`` are provided, num_output_patches should be large enough to accommodate + their lengths, i.e., num_output_patches * output_patch_size >= future_length + future_target + Target tensor of shape (batch_size, future_length) used during training. If ``future_covariates`` are provided, both + target and future_covariates must have the same shape. + future_target_mask + Binary mask tensor of same shape as `future_target` indicating which values are valid (1) vs missing (0) + If missing, the `future_target_mask` will be automatically constructed based on the NaN values in `future_target`. + output_attentions + Whether to return attention weights, by default False + + Returns + ------- + Chronos2Output containing: + - loss: Training loss, if `future_target` is provided + - quantile_preds: Quantile predictions of shape (batch_size, num_quantiles, num_output_patches * output_patch_size). + quantile_preds will contain an entry for every time series in the context batch regardless of whether it was a + known future covariate. 
+ - enc_time_self_attn_weights: Time self attention weights, if output_attentions=True + - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True + """ + batch_size = context.shape[0] + ( + encoder_outputs, + loc_scale, + patched_future_covariates_mask, + num_context_patches, + ) = self.encode( + context=context, + context_mask=context_mask, + group_ids=group_ids, + future_covariates=future_covariates, + future_covariates_mask=future_covariates_mask, + num_output_patches=num_output_patches, + future_target=future_target, + future_target_mask=future_target_mask, + output_attentions=output_attentions, + ) + hidden_states: torch.Tensor = encoder_outputs[0] + assert hidden_states.shape == ( + batch_size, + num_context_patches + 1 + num_output_patches, + self.model_dim, + ) + + # slice the last num_output_patches hidden states to be input into the output_patch_embedding + forecast_embeds = hidden_states[:, -num_output_patches:] + quantile_preds: torch.Tensor = self.output_patch_embedding(forecast_embeds) + quantile_preds = rearrange( + quantile_preds, + "b n (q p) -> b q (n p)", + n=num_output_patches, + q=self.num_quantiles, + p=self.chronos_config.output_patch_size, + ) + + loss = ( + self._compute_loss( + quantile_preds=quantile_preds, + future_target=future_target, + future_target_mask=future_target_mask, + patched_future_covariates_mask=patched_future_covariates_mask, + loc_scale=loc_scale, + num_output_patches=num_output_patches, + ) + if future_target is not None + else None + ) + + # Unscale predictions + quantile_preds = rearrange( + quantile_preds, + "b q h -> b (q h)", + b=batch_size, + q=self.num_quantiles, + h=num_output_patches * self.chronos_config.output_patch_size, + ) + quantile_preds = self.instance_norm.inverse(quantile_preds, loc_scale) + quantile_preds = rearrange( + quantile_preds, + "b (q h) -> b q h", + q=self.num_quantiles, + h=num_output_patches * self.chronos_config.output_patch_size, + ) + + return Chronos2Output( + loss=loss, + quantile_preds=quantile_preds, + enc_time_self_attn_weights=encoder_outputs.all_time_self_attn_weights, + enc_group_self_attn_weights=encoder_outputs.all_group_self_attn_weights, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py new file mode 100644 index 0000000000000..b99a930784140 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py @@ -0,0 +1,391 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +import math + +import torch +from einops import rearrange, repeat +from torch.utils.data import DataLoader + +from iotdb.ainode.core.exception import InferenceModelInternalException +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.chronos2.dataset import Chronos2Dataset, DatasetMode +from iotdb.ainode.core.model.chronos2.utils import ( + interpolate_quantiles, + weighted_quantile, +) + +logger = Logger() + + +class Chronos2Pipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) + + def preprocess(self, inputs): + inputs = super().preprocess(inputs) + return inputs + + @property + def model_context_length(self) -> int: + return self.model.chronos_config.context_length + + @property + def model_output_patch_size(self) -> int: + return self.model.chronos_config.output_patch_size + + @property + def model_prediction_length(self) -> int: + return ( + self.model.chronos_config.max_output_patches + * self.model.chronos_config.output_patch_size + ) + + @property + def quantiles(self) -> list[float]: + return self.model.chronos_config.quantiles + + @property + def max_output_patches(self) -> int: + return self.model.chronos_config.max_output_patches + + @staticmethod + def _slide_context_and_future_covariates( + context: torch.Tensor, future_covariates: torch.Tensor, slide_by: int + ) -> tuple[torch.Tensor, torch.Tensor]: + # replace context with future_covariates, where the values of future covariates are known (not NaN) + future_covariates_slice = future_covariates[..., :slide_by] + context[..., -slide_by:] = torch.where( + torch.isnan(future_covariates_slice), + context[..., -slide_by:], + future_covariates_slice, + ) + # shift future_covariates + future_covariates = future_covariates[..., slide_by:] + + return context, future_covariates + + @staticmethod + def _get_prob_mass_per_quantile_level( + quantile_levels: torch.Tensor, + ) -> torch.Tensor: + """ + Computes normalized probability masses for quantile levels using trapezoidal rule approximation. + + Each quantile receives probability mass proportional to the width of its surrounding interval, + creating a piecewise uniform distribution. The mass for quantile q_i is computed as + (q_{i+1} - q_{i-1}) / 2, where q_0 = 0 and q_{n+1} = 1. 
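        For example, with quantile_levels = [0.1, 0.5, 0.9] the padded boundaries are
        [0.0, 0.1, 0.5, 0.9, 1.0], the raw masses are
        [(0.5 - 0.0)/2, (0.9 - 0.1)/2, (1.0 - 0.5)/2] = [0.25, 0.4, 0.25], and dividing
        by their sum (0.9) gives approximately [0.278, 0.444, 0.278].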
+ + Parameters + ---------- + quantile_levels : torch.Tensor + The quantile levels, must be strictly in (0, 1) + + Returns + ------- + torch.Tensor + The normalized probability mass per quantile + """ + assert quantile_levels.ndim == 1 + assert quantile_levels.min() > 0.0 and quantile_levels.max() < 1.0 + + device = quantile_levels.device + boundaries = torch.cat( + [ + torch.tensor([0.0], device=device), + quantile_levels, + torch.tensor([1.0], device=device), + ] + ) + prob_mass = (boundaries[2:] - boundaries[:-2]) / 2 + return prob_mass / prob_mass.sum() + + def _prepare_inputs_for_long_horizon_unrolling( + self, + context: torch.Tensor, + group_ids: torch.Tensor, + future_covariates: torch.Tensor, + unrolled_quantiles: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Expand the context, future_covariates and group_ids along a new "quantile" axis + if future_covariates is not None: + future_covariates = repeat( + future_covariates, "b t -> b q t", q=len(unrolled_quantiles) + ) + context = repeat(context, "b t -> b q t", q=len(unrolled_quantiles)) + group_ids = repeat(group_ids, "b -> b q", q=len(unrolled_quantiles)) + # Shift the group_ids so that mixing is enabled only for time series with the same group_id and + # at the same quantile level, e.g., if the group_ids were [0, 0, 1, 1, 1] initially, after expansion + # and shifting they will be: + # [[0, 1, 2, 3, 4, 5, 6, 7, 8], + # [0, 1, 2, 3, 4, 5, 6, 7, 8], + # [9, 10, 11, 12, 13, 14, 15, 16, 17], + # [9, 10, 11, 12, 13, 14, 15, 16, 17], + # [9, 10, 11, 12, 13, 14, 15, 16, 17]] + group_ids = group_ids * len(unrolled_quantiles) + torch.arange( + len(unrolled_quantiles), device=self.model.device + ).unsqueeze(0) + # We unroll the quantiles in unrolled_quantiles to the future and each unrolled quantile gives + # len(self.quantiles) predictions, so we end up with len(unrolled_quantiles) * len(self.quantiles) + # "samples". unrolled_sample_weights specifies the amount of probability mass covered by each sample. + # Note that this effectively leads to shrinking of the probability space but it is better heuristic + # than just using the median to unroll, which leads to uncertainty collapse. 
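        # For example, with the default 9 unrolled quantiles and, say, 9 model quantile
        # levels, both probability-mass vectors returned by
        # _get_prob_mass_per_quantile_level sum to 1, so the outer product below is a
        # (9, 9) matrix that also sums to 1; entry [i, j] is the probability mass
        # assigned to the j-th model quantile predicted along the i-th unrolled path.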
+ unrolled_sample_weights = torch.outer( + self._get_prob_mass_per_quantile_level(unrolled_quantiles), + self._get_prob_mass_per_quantile_level(torch.tensor(self.quantiles)), + ) + + return context, group_ids, future_covariates, unrolled_sample_weights + + def _autoregressive_unroll_for_long_horizon( + self, + context: torch.Tensor, + group_ids: torch.Tensor, + future_covariates: torch.Tensor, + prediction: torch.Tensor, + unrolled_quantiles: torch.Tensor, + unrolled_sample_weights: torch.Tensor, + num_output_patches: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Get unrolled_quantiles from prediction and append it to the expanded context + prediction_unrolled = interpolate_quantiles( + query_quantile_levels=unrolled_quantiles, + original_quantile_levels=self.quantiles, + original_values=rearrange(prediction, "b q h -> b h q"), + ) + prediction_unrolled = rearrange(prediction_unrolled, "b h q -> b q h") + context = torch.cat([context, prediction_unrolled], dim=-1)[ + ..., -self.model_context_length : + ] + n_paths = len(unrolled_quantiles) + + # Shift future_covariates by prediction.shape[-1] while replacing the predicted values + # of future covariates in the context with their actual values, if known + if future_covariates is not None: + context, future_covariates = self._slide_context_and_future_covariates( + context=context, + future_covariates=future_covariates, + slide_by=prediction.shape[-1], + ) + + # Reshape (batch, n_paths, context_length) -> (batch * n_paths, context_length) + prediction = self._predict_step( + context=rearrange(context, "b n t -> (b n) t"), + future_covariates=( + rearrange(future_covariates, "b n t -> (b n) t") + if future_covariates is not None + else None + ), + group_ids=rearrange(group_ids, "b n -> (b n)"), + num_output_patches=num_output_patches, + ) + # Reshape predictions from (batch * n_paths, n_quantiles, length) to (batch, n_paths * n_quantiles, length) + prediction = rearrange(prediction, "(b n) q h -> b (n q) h", n=n_paths) + # Reduce `n_paths * n_quantiles` to n_quantiles and transpose back + prediction = weighted_quantile( + query_quantile_levels=self.quantiles, + sample_weights=rearrange(unrolled_sample_weights, "n q -> (n q)"), + samples=rearrange(prediction, "b (n q) h -> b h (n q)", n=n_paths), + ) + prediction = rearrange(prediction, "b h q -> b q h") + + return prediction, context, future_covariates + + def forecast(self, inputs, **infer_kwargs): + model_prediction_length = self.model_prediction_length + prediction_length = infer_kwargs.get("predict_length", 96) + # The maximum number of output patches to generate in a single forward pass before the long-horizon heuristic kicks in. Note: A value larger + # than the model's default max_output_patches may lead to degradation in forecast accuracy, defaults to a model-specific value + max_output_patches = infer_kwargs.get( + "max_output_patches", self.max_output_patches + ) + # The set of quantiles to use when making long-horizon predictions; must be a subset of the model's default quantiles. These quantiles + # are appended to the historical context and input into the model autoregressively to generate long-horizon predictions. 
Note that the + # effective batch size increases by a factor of `len(unrolled_quantiles)` when making long-horizon predictions, + # by default [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + unrolled_quantiles = infer_kwargs.get( + "unrolled_quantiles", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + ) + + if not set(unrolled_quantiles).issubset(self.quantiles): + raise ValueError( + f"Unrolled quantiles must be a subset of the model's quantiles. " + f"Found: {unrolled_quantiles=}, model_quantiles={self.quantiles}" + ) + unrolled_quantiles_tensor = torch.tensor(unrolled_quantiles) + + if prediction_length > model_prediction_length: + msg = ( + f"We recommend keeping prediction length <= {model_prediction_length}. " + "The quality of longer predictions may degrade since the model is not optimized for it. " + ) + logger.warning(msg) + + context_length = inputs.shape[-1] + if context_length > self.model_context_length: + logger.warning( + f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. " + f"Resetting context_length to {self.model_context_length}." + ) + context_length = self.model_context_length + + test_dataset = Chronos2Dataset.convert_inputs( + inputs=inputs, + context_length=context_length, + prediction_length=prediction_length, + batch_size=256, + output_patch_size=self.model_output_patch_size, + mode=DatasetMode.TEST, + ) + test_loader = DataLoader( + test_dataset, + batch_size=None, + pin_memory=True, + shuffle=False, + drop_last=False, + ) + + all_predictions: list[torch.Tensor] = [] + for batch in test_loader: + assert batch["future_target"] is None + batch_context = batch["context"] + batch_group_ids = batch["group_ids"] + batch_future_covariates = batch["future_covariates"] + batch_target_idx_ranges = batch["target_idx_ranges"] + + batch_prediction = self._predict_batch( + context=batch_context, + group_ids=batch_group_ids, + future_covariates=batch_future_covariates, + unrolled_quantiles_tensor=unrolled_quantiles_tensor, + prediction_length=prediction_length, + max_output_patches=max_output_patches, + target_idx_ranges=batch_target_idx_ranges, + ) + all_predictions.extend(batch_prediction) + + return all_predictions + + def _predict_batch( + self, + context: torch.Tensor, + group_ids: torch.Tensor, + future_covariates: torch.Tensor, + unrolled_quantiles_tensor: torch.Tensor, + prediction_length: int, + max_output_patches: int, + target_idx_ranges: list[tuple[int, int]], + ) -> list[torch.Tensor]: + context = context.to(device=self.model.device, dtype=torch.float32) + group_ids = group_ids.to(device=self.model.device) + future_covariates = future_covariates.to( + device=self.model.device, dtype=torch.float32 + ) + + def get_num_output_patches(remaining_horizon: int): + num_output_patches = math.ceil( + remaining_horizon / self.model_output_patch_size + ) + num_output_patches = min(num_output_patches, max_output_patches) + + return num_output_patches + + predictions = [] + remaining = prediction_length + + # predict first set of patches up to max_output_patches + prediction: torch.Tensor = self._predict_step( + context=context, + group_ids=group_ids, + future_covariates=future_covariates, + num_output_patches=get_num_output_patches(remaining), + ) + predictions.append(prediction) + remaining -= prediction.shape[-1] + + # prepare inputs for long horizon prediction + if remaining > 0: + context, group_ids, future_covariates, unrolled_sample_weights = ( + self._prepare_inputs_for_long_horizon_unrolling( + 
context=context, + group_ids=group_ids, + future_covariates=future_covariates, + unrolled_quantiles=unrolled_quantiles_tensor, + ) + ) + + # long horizon heuristic + while remaining > 0: + prediction, context, future_covariates = ( + self._autoregressive_unroll_for_long_horizon( + context=context, + group_ids=group_ids, + future_covariates=future_covariates, + prediction=prediction, + unrolled_quantiles=unrolled_quantiles_tensor, + unrolled_sample_weights=unrolled_sample_weights, + num_output_patches=get_num_output_patches(remaining), + ) + ) + predictions.append(prediction) + remaining -= prediction.shape[-1] + + batch_prediction = torch.cat(predictions, dim=-1)[..., :prediction_length].to( + dtype=torch.float32, device="cpu" + ) + + return [batch_prediction[start:end] for (start, end) in target_idx_ranges] + + def _predict_step( + self, + context: torch.Tensor, + group_ids: torch.Tensor, + future_covariates: torch.Tensor | None, + num_output_patches: int, + ) -> torch.Tensor: + kwargs = {} + if future_covariates is not None: + output_size = num_output_patches * self.model_output_patch_size + + if output_size > future_covariates.shape[1]: + batch_size = len(future_covariates) + padding_size = output_size - future_covariates.shape[1] + padding_tensor = torch.full( + (batch_size, padding_size), + fill_value=torch.nan, + device=future_covariates.device, + ) + future_covariates = torch.cat( + [future_covariates, padding_tensor], dim=1 + ) + + else: + future_covariates = future_covariates[..., :output_size] + kwargs["future_covariates"] = future_covariates + with torch.no_grad(): + prediction: torch.Tensor = self.model( + context=context, + group_ids=group_ids, + num_output_patches=num_output_patches, + **kwargs, + ).quantile_preds.to(context) + + return prediction + + def postprocess(self, output: torch.Tensor): + return output[0].mean(dim=1, keepdim=True) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/utils.py new file mode 100644 index 0000000000000..279652abba95a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/utils.py @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from typing import List + +import torch +from einops import repeat + + +def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: + max_len = max(len(c) for c in tensors) + padded = [] + for c in tensors: + assert isinstance(c, torch.Tensor) + assert c.ndim == 1 + padding = torch.full( + size=(max_len - len(c),), fill_value=torch.nan, device=c.device + ) + padded.append(torch.concat((padding, c), dim=-1)) + return torch.stack(padded) + + +def interpolate_quantiles( + query_quantile_levels: torch.Tensor | list[float], + original_quantile_levels: torch.Tensor | list[float], + original_values: torch.Tensor, +) -> torch.Tensor: + """ + Interpolates quantile values at specified query levels using linear interpolation using original + quantile levels and their corresponding values. This behaves similar to `torch.quantile` in terms of + the linear interpolation but also supports non-equidistant original quantile levels. + + Parameters + ---------- + query_quantile_levels : torch.Tensor | list[float] + The quantile levels at which to interpolate values, all levels must be between 0 and 1 + original_quantile_levels : torch.Tensor | list[float] + The quantile levels corresponding to the original values, all levels must be between 0 and 1. + Can be a 1D tensor or list matching the last dimension of `original_values`, or a tensor with the + same shape as `original_values` + original_values : torch.Tensor + The values corresponding to the original quantile levels, can have any number of leading dimensions + + Returns + ------- + torch.Tensor + The interpolated quantiles at the query quantile levels. All leading dimensions have the same size + as `original_values` and the last dimension has size `len(query_quantile_levels)`. + """ + assert torch.is_floating_point( + original_values + ), "`original_values` must be a floating point tensor" + orig_dtype = original_values.dtype + if isinstance(query_quantile_levels, list): + query_quantile_levels = torch.tensor(query_quantile_levels, dtype=torch.float32) + if isinstance(original_quantile_levels, list): + original_quantile_levels = torch.tensor( + original_quantile_levels, dtype=torch.float32 + ) + + assert ( + query_quantile_levels.ndim == 1 + ), "`query_quantile_levels` must be 1-dimensional" + if original_quantile_levels.ndim > 1: + assert ( + original_quantile_levels.shape == original_values.shape + ), "If `original_quantile_levels` is not 1D, its shape must match `original_values`" + else: + assert ( + len(original_quantile_levels) == original_values.shape[-1] + ), "If `original_quantile_levels` is 1D, its length must match the last dim of `original_values`" + assert ( + query_quantile_levels.min() >= 0.0 and query_quantile_levels.max() <= 1.0 + ), "`query_quantile_levels` must be between 0 and 1" + assert ( + original_quantile_levels.min() >= 0.0 and original_quantile_levels.max() <= 1.0 + ), "`original_quantile_levels` must be between 0 and 1" + original_quantile_levels = torch.clamp(original_quantile_levels, min=0.0, max=1.0) + + device = original_values.device + query_quantile_levels = query_quantile_levels.to(device) + original_quantile_levels = original_quantile_levels.to(device) + original_values = original_values.to(torch.float32) + + orig_values_shape = original_values.shape + num_original_quantiles = original_quantile_levels.shape[-1] + original_values = original_values.reshape(-1, num_original_quantiles) + batch_size = original_values.shape[0] + + # If original_quantile_levels is 1D, expand it to match the batch dimension 
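    # (a shared 1D (num_quantiles,) vector is broadcast to a (batch_size, num_quantiles)
    # view via expand, without copying the underlying data)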
+ if original_quantile_levels.ndim == 1: + original_quantile_levels = original_quantile_levels.expand(batch_size, -1) + else: + original_quantile_levels = original_quantile_levels.reshape( + -1, num_original_quantiles + ) + + # Sort original quantile levels and the corresponding values + sorted_levels, sorted_indices = torch.sort(original_quantile_levels, dim=-1) + sorted_values = torch.gather(original_values, dim=-1, index=sorted_indices) + + # Add extreme quantiles (0., 1.) to handle extrapolation and queries at 0 or 1 + zeros_padding = torch.zeros((batch_size, 1), dtype=torch.float32, device=device) + ones_padding = torch.ones((batch_size, 1), dtype=torch.float32, device=device) + + # Only pad when extreme quantiles are not available in original_quantile_levels + sorted_levels_with_padding = [] + sorted_values_with_padding = [] + if original_quantile_levels.min() > 0.0: + sorted_levels_with_padding.append(zeros_padding) + sorted_values_with_padding.append(sorted_values[:, :1]) + sorted_levels_with_padding.append(sorted_levels) + sorted_values_with_padding.append(sorted_values) + if original_quantile_levels.max() < 1.0: + sorted_levels_with_padding.append(ones_padding) + sorted_values_with_padding.append(sorted_values[:, -1:]) + + sorted_levels = torch.cat(sorted_levels_with_padding, dim=-1).contiguous() + sorted_values = torch.cat(sorted_values_with_padding, dim=-1) + + # Shape goes from (num_queries,) to (batch_size, num_queries). + query_levels_expanded = repeat( + query_quantile_levels, "q -> b q", b=batch_size + ).contiguous() + + # Find (sorted) index of smallest original quantile level strictly larger than the query quantile level + upper_indices = torch.searchsorted(sorted_levels, query_levels_expanded, right=True) + upper_indices = torch.clamp(upper_indices, max=sorted_levels.shape[-1] - 1) + lower_indices = upper_indices - 1 + + # Gather the lower and upper levels and values for each item in the batch + lower_levels = torch.gather(sorted_levels, dim=1, index=lower_indices) + upper_levels = torch.gather(sorted_levels, dim=1, index=upper_indices) + lower_values = torch.gather(sorted_values, dim=1, index=lower_indices) + upper_values = torch.gather(sorted_values, dim=1, index=upper_indices) + + # Perform linear interpolation + level_diff = upper_levels - lower_levels + weight = torch.nan_to_num( + (query_levels_expanded - lower_levels) / level_diff, nan=0.0 + ) + interpolated_values = lower_values + weight * (upper_values - lower_values) + + final_shape = (*orig_values_shape[:-1], len(query_quantile_levels)) + return interpolated_values.reshape(final_shape).to(orig_dtype) + + +def weighted_quantile( + query_quantile_levels: torch.Tensor | list[float], + sample_weights: torch.Tensor | list[float], + samples: torch.Tensor, +): + """ + Computes quantiles from a distribution specified by `samples` and their corresponding probability mass + `sample_weights`. `samples` are first sorted along the last axis and an empirical cumulative distribution + function (CDF) is constructed. Specific `query_quantile_levels` are then interpolated using this CDF. + + Parameters + ---------- + query_quantile_levels : torch.Tensor | list[float] + The quantile levels to interpolate from the empirical CDF, must be between 0 and 1 + sample_weights : torch.Tensor | list[float] + The weights corresponding to each sample, must be non-negative. 
The sample_weights correspond to the + last axis of `samples` and all leading batch dimensions share the same sample weights + samples : torch.Tensor + The sample values used to construct the empirical CDF along the last axis. The last dim must + match the length of `sample_weights`, can have any number of leading dimensions + + Returns + ------- + torch.Tensor + The interpolated quantiles at the query quantile levels. All leading dimensions have the same size + as `samples` and the last dimension has size `len(query_quantile_levels)`. + """ + # FIXME: this interpolation works reasonably well in practice but may not be the best way to extrapolate + assert torch.is_floating_point( + samples + ), "`original_values` must be a floating point tensor" + orig_dtype = samples.dtype + if isinstance(query_quantile_levels, list): + query_quantile_levels = torch.tensor(query_quantile_levels, dtype=torch.float32) + if isinstance(sample_weights, list): + sample_weights = torch.tensor(sample_weights, dtype=torch.float32) + + assert ( + query_quantile_levels.ndim == 1 and sample_weights.ndim == 1 + ), "`query_quantile_levels` and `sample_weights` must be 1-dimensional" + assert ( + len(sample_weights) == samples.shape[-1] + ), "the last dim of `samples` must be equal to the length of `sample_weights`" + assert ( + query_quantile_levels.min() >= 0.0 and query_quantile_levels.max() <= 1.0 + ), "`query_quantile_levels` must be between 0 and 1" + assert sample_weights.min() > 0.0, "`sample_weights` must be > 0" + + device = samples.device + query_quantile_levels = query_quantile_levels.to(device) + sample_weights = sample_weights.to(device) + samples = samples.to(torch.float32) + + orig_samples_shape = samples.shape + num_samples = len(sample_weights) + samples = samples.reshape(-1, num_samples) + batch_size = samples.shape[0] + + # Normalize and expand weights + sample_weights = sample_weights / sample_weights.sum(dim=-1, keepdim=True) + sample_weights = sample_weights.expand(batch_size, -1).contiguous() + + # Sort samples and the corresponding weights + sorted_samples, sort_indices = torch.sort(samples, dim=-1) + sorted_weights = torch.gather(sample_weights, dim=-1, index=sort_indices) + + # Compute cumulative weights + cumul_weights = torch.cumsum(sorted_weights, dim=-1) + cumul_weights = torch.clamp(cumul_weights, min=0.0, max=1.0) + + # Get interpolated quantiles + interpolated_quantiles = interpolate_quantiles( + query_quantile_levels=query_quantile_levels, + original_quantile_levels=cumul_weights, + original_values=sorted_samples, + ) + + # Reshape to original shape + final_shape = (*orig_samples_shape[:-1], len(query_quantile_levels)) + return interpolated_quantiles.reshape(final_shape).to(dtype=orig_dtype) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 5d86e7c588f51..bcb4a5e2056eb 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -131,4 +131,16 @@ def __repr__(self): }, _transformers_registered=True, ), + "chronos2": ModelInfo( + model_id="chronos2", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="t5", + pipeline_cls="pipeline_chronos2.Chronos2Pipeline", + repo_id="amazon/chronos-2", + auto_map={ + "AutoConfig": "config.Chronos2ForecastingConfig", + "AutoModelForCausalLM": "model.Chronos2Model", + }, + ), }
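For context, a minimal usage sketch of the newly registered pipeline follows. How AINode constructs a Chronos2Pipeline from the "chronos2" ModelInfo entry above is assumed rather than shown, and the input shape and helper name are illustrative only; the forecast() and postprocess() calls mirror the pipeline code added in this change.

# Hypothetical sketch: assumes `pipeline` is an already-constructed Chronos2Pipeline
# resolved from the "chronos2" ModelInfo entry registered above.
import torch

def chronos2_smoke_test(pipeline) -> torch.Tensor:
    # One univariate series of length 512; forecast() only relies on inputs.shape[-1]
    # and batches the series internally via Chronos2Dataset.convert_inputs.
    context = torch.randn(1, 512)
    # Returns one quantile-prediction tensor per target group; a predict_length beyond
    # the model's native horizon is covered by the long-horizon unrolling heuristic.
    predictions = pipeline.forecast(context, predict_length=192)
    # postprocess() keeps the first group and averages over the quantile dimension.
    return pipeline.postprocess(predictions)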