Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.

Commit c9b61fa

Browse files
authored
Fix electricity forecasting tutorial (#127)
1 parent 23e6a0b commit c9b61fa

3 files changed

Lines changed: 14 additions & 7 deletions

File tree

docs/source/index.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ Lightning-Sandbox documentation
1818
:caption: Start here
1919
:glob:
2020

21-
notebooks/*
2221
notebooks/**/*
2322

2423
.. raw:: html
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
title: Electricity Price Forecasting with N-BEATS
22
author: Ethan Harris (ethan@pytorchlightning.ai)
33
created: 2021-11-23
4-
updated: 2021-11-23
4+
updated: 2021-12-16
55
license: CC BY-SA
66
build: 3
77
tags:
88
- Tabular
99
- Forecasting
10+
- Timeseries
1011
description: |
1112
This tutorial covers using Lightning Flash and it's integration with PyTorch Forecasting to train an autoregressive
1213
model (N-BEATS) on hourly electricity pricing data. We show how the built-in interpretability tools from PyTorch
@@ -15,7 +16,7 @@ description: |
1516
bonus, we show hat we can resample daily observations from the data to discover weekly trends instead.
1617
requirements:
1718
- pandas==1.1.5
18-
- lightning-flash[tabular]>=0.5.2
19+
- lightning-flash[tabular]>=0.6.0
1920
accelerator:
2021
- GPU
2122
- CPU

flash_tutorials/electricity_forecasting/electricity_forecasting.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# %%
1212

1313
import os
14+
from typing import Any, Dict
1415

1516
import flash
1617
import matplotlib.pyplot as plt
@@ -196,9 +197,15 @@ def preprocess(df: pd.DataFrame, frequency: str = "1H") -> pd.DataFrame:
196197
# %%
197198

198199

199-
def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
200+
def plot_interpretation(model_path: str, predict_df: pd.DataFrame, parameters: Dict[str, Any]):
200201
model = TabularForecaster.load_from_checkpoint(model_path)
201-
predictions = model.predict(predict_df)
202+
datamodule = TabularForecastingData.from_data_frame(
203+
parameters=parameters,
204+
predict_data_frame=predict_df,
205+
batch_size=256,
206+
)
207+
trainer = flash.Trainer(gpus=int(torch.cuda.is_available()))
208+
predictions = trainer.predict(model, datamodule=datamodule)
202209
predictions, inputs = convert_predictions(predictions)
203210
model.pytorch_forecasting_model.plot_interpretation(inputs, predictions, idx=0)
204211
plt.show()
@@ -208,7 +215,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
208215
# And now we run the function to plot the trend and seasonality curves:
209216

210217
# %%
211-
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_hourly)
218+
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_hourly, datamodule.parameters)
212219

213220
# %% [markdown]
214221
# It worked! The plot shows that the `TabularForecaster` does a reasonable job of modelling the time series and also
@@ -281,7 +288,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
281288
# Now let's look at what it learned:
282289

283290
# %%
284-
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_daily)
291+
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_daily, datamodule.parameters)
285292

286293
# %% [markdown]
287294
# Success! We can now also see weekly trends / seasonality uncovered by our new model.

0 commit comments

Comments
 (0)