Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runtime/databricks/automl_runtime/forecast/deepar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def predict(self,

pred_df = pred_df.rename(columns={'index': self._time_col})
if self._id_cols:
id_col_name = '-'.join(self._id_cols)
id_col_name = self._id_cols[0]
pred_df = pred_df.rename(columns={'item_id': id_col_name})
else:
pred_df = pred_df.drop(columns='item_id')
Expand Down
25 changes: 14 additions & 11 deletions runtime/databricks/automl_runtime/forecast/deepar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional
from typing import List, Optional, Union, Dict

import pandas as pd


def validate_and_generate_index(df: pd.DataFrame,
time_col: str,
frequency_unit: str,
Expand Down Expand Up @@ -66,10 +65,12 @@ def validate_and_generate_index(df: pd.DataFrame,

return new_index_full

def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
frequency_unit: str,
frequency_quantity: int,
id_cols: Optional[List[str]] = None):
def set_index_and_fill_missing_time_steps(
df: pd.DataFrame, time_col: str,
frequency_unit: str,
frequency_quantity: int,
id_cols: Optional[List[str]] = None
) -> Union[pd.DataFrame, Dict[any, pd.DataFrame]]:
"""
Transform the input dataframe to an acceptable format for the GluonTS library.

Expand All @@ -95,14 +96,16 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
valid_index = validate_and_generate_index(df=df, time_col=time_col, frequency_unit=frequency_unit, frequency_quantity=frequency_quantity)

if id_cols is not None:
if len(id_cols) > 1:
raise ValueError("DeepAR does not support multiple time series id columns")
df_dict = {}
for grouped_id, grouped_df in df.groupby(id_cols):
if isinstance(grouped_id, tuple):
ts_id = "-".join([str(x) for x in grouped_id])
else:
ts_id = str(grouped_id)
df_dict[ts_id] = (grouped_df.set_index(time_col).sort_index()
.reindex(valid_index).drop(id_cols, axis=1))
# TODO (ML-52171): Fix the DeepAR library to support multi-time series id columns
# For now, DeepAR is dropped for multiple id_cols
raise ValueError("DeepAR does not support multiple time series id columns")
df_dict[grouped_id] = (grouped_df.set_index(time_col).sort_index()
.reindex(valid_index).drop(id_cols, axis=1))

return df_dict

Expand Down
Loading