1111# %%
1212
1313import os
14+ from typing import Any , Dict
1415
1516import flash
1617import matplotlib .pyplot as plt
@@ -196,9 +197,15 @@ def preprocess(df: pd.DataFrame, frequency: str = "1H") -> pd.DataFrame:
196197# %%
197198
198199
199- def plot_interpretation (model_path : str , predict_df : pd .DataFrame ):
200+ def plot_interpretation (model_path : str , predict_df : pd .DataFrame , parameters : Dict [ str , Any ] ):
200201 model = TabularForecaster .load_from_checkpoint (model_path )
201- predictions = model .predict (predict_df )
202+ datamodule = TabularForecastingData .from_data_frame (
203+ parameters = parameters ,
204+ predict_data_frame = predict_df ,
205+ batch_size = 256 ,
206+ )
207+ trainer = flash .Trainer (gpus = int (torch .cuda .is_available ()))
208+ predictions = trainer .predict (model , datamodule = datamodule )
202209 predictions , inputs = convert_predictions (predictions )
203210 model .pytorch_forecasting_model .plot_interpretation (inputs , predictions , idx = 0 )
204211 plt .show ()
@@ -208,7 +215,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
208215# And now we run the function to plot the trend and seasonality curves:
209216
210217# %%
211- plot_interpretation (trainer .checkpoint_callback .best_model_path , df_energy_hourly )
218+ plot_interpretation (trainer .checkpoint_callback .best_model_path , df_energy_hourly , datamodule . parameters )
212219
213220# %% [markdown]
214221# It worked! The plot shows that the `TabularForecaster` does a reasonable job of modelling the time series and also
@@ -281,7 +288,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
281288# Now let's look at what it learned:
282289
283290# %%
284- plot_interpretation (trainer .checkpoint_callback .best_model_path , df_energy_daily )
291+ plot_interpretation (trainer .checkpoint_callback .best_model_path , df_energy_daily , datamodule . parameters )
285292
286293# %% [markdown]
287294# Success! We can now also see weekly trends / seasonality uncovered by our new model.
0 commit comments