XGBoost Forecaster API

XGBoost Substation Forecaster

This package implements an XGBoost-based model to forecast power flows at NGED primary substations using numerical weather prediction (NWP) forecasts. It implements the BaseForecaster protocol defined in ml_core.

Features

  • Unified ML Interface: Implements the BaseForecaster protocol, allowing seamless integration with the Dagster orchestration pipeline.
  • Multi-NWP Support: Ingests forecasts from multiple NWP providers simultaneously. Secondary NWP features are prefixed with their model name (e.g., gfs_temperature_2m), and all NWPs are joined using a 3-hour availability delay.
  • Dynamic Seasonal Lags: Prevents lookahead bias by calculating autoregressive lags dynamically from the forecast lead time, so the model always uses the most recent historical data actually available at that lead time (lag_days = max(1, ceil((lead_time_days + telemetry_delay_days) / 7)) * 7).
  • Rigorous Backtesting: Supports simulating real-time inference via the collapse_lead_times parameter. When enabled, it filters NWP data to keep only the latest available forecast for each valid time, enforcing the 3-hour availability delay.
  • H3-based Weather Matching: Automatically matches substation coordinates to H3 resolution 5 cells used in the weather data.
  • Ensemble Averaging: Averages weather variables across ensemble members for robust feature engineering.
  • Temporal Features: Includes cyclical temporal features (sine/cosine for hour and day of year) and day of week.
  • Long-Range Horizon Handling: Supports 14-day (336h) forecasts at 30-minute resolution. The lead_time_hours is passed as a feature to the XGBoost model, allowing it to learn the decay in NWP skill over time.
  • Physical Wind Logic: Wind speed and direction are interpolated using Cartesian u and v components instead of circular interpolation. This avoids "phantom high wind" artifacts during rapid direction shifts and ensures physical correctness.
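The u/v logic in the last bullet is easy to demonstrate outside Polars. Below is a minimal plain-Python sketch (the helper names to_uv/from_uv are illustrative, not part of the package; the direction convention matches the pipeline's atan2 expression): averaging the Cartesian components of two winds that straddle a heading preserves both speed and direction, while naively averaging the angles can point the wind the wrong way.

```python
import math

def to_uv(speed: float, direction_deg: float) -> tuple[float, float]:
    """Convert speed/direction to u/v components, inverting the pipeline's
    convention: direction = (atan2(u, v) * 180 / pi + 180) % 360."""
    theta = math.radians(direction_deg - 180.0)
    return speed * math.sin(theta), speed * math.cos(theta)

def from_uv(u: float, v: float) -> tuple[float, float]:
    speed = math.hypot(u, v)
    direction = (math.degrees(math.atan2(u, v)) + 180.0) % 360.0
    return speed, direction

# Two consecutive forecast steps: 10 m/s at 350 deg, then 10 m/s at 30 deg.
u1, v1 = to_uv(10.0, 350.0)
u2, v2 = to_uv(10.0, 30.0)

# Midpoint via components: the direction rotates smoothly through north.
speed, direction = from_uv((u1 + u2) / 2, (v1 + v2) / 2)
print(round(speed, 2), round(direction, 1))  # ~9.4 m/s from ~10 deg

# Naive angular average: (350 + 30) / 2 = 190 deg, nearly the opposite heading.
print((350.0 + 30.0) / 2)
```

The small drop in speed (10 to ~9.4 m/s) is the physically expected effect of vector-averaging two directions 40 degrees apart, rather than a "phantom" artifact.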

Installation

This package is part of the uv workspace. Install all dependencies from the root:

uv sync

Usage

This package is intended to be used as part of the Dagster pipeline. The XGBoostForecaster class handles the full lifecycle of the model, including training and inference.

xgboost_forecaster.config

Classes

XGBoostHyperparameters

Bases: BaseModel

Hyperparameters for the XGBoost model.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/config.py
class XGBoostHyperparameters(BaseModel):
    """Hyperparameters for the XGBoost model."""

    learning_rate: float = Field(default=0.01, gt=0.0)
    n_estimators: int = Field(default=100, gt=0)
    max_depth: int = Field(default=6, gt=0)
    enable_categorical: bool = Field(default=True)

xgboost_forecaster.data

Data loading and preprocessing for XGBoost forecasting.

Classes

DataConfig dataclass

Configuration for data loading and preprocessing.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/data.py
@dataclasses.dataclass
class DataConfig:
    """Configuration for data loading and preprocessing."""

    base_power_path: Path = _SETTINGS.nged_data_path / "delta" / "live_primary_flows"
    base_weather_path: Path = _SETTINGS.nwp_data_path / "ECMWF" / "ENS"
    # Resolution 5 is the fixed standard for this model as it balances spatial
    # precision with feature dimensionality for the XGBoost model.
    h3_res: int = 5
    resolution: str = "30m"

Functions

construct_historical_weather(start_date, end_date, h3_indices, config=None)

Construct a continuous historical weather timeseries by stitching NWP runs.

Parameters:

  start_date (date): Start date for the timeseries. Required.
  end_date (date): End date for the timeseries. Required.
  h3_indices (list[int]): List of H3 indices to filter for. Required.
  config (DataConfig | None): Data configuration. Default: None.

Returns:

  DataFrame[Nwp]: A Patito DataFrame containing the stitched NWP data.

Raises:

  RuntimeError: If no NWP files are found in the date range.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/data.py
def construct_historical_weather(
    start_date: date, end_date: date, h3_indices: list[int], config: DataConfig | None = None
) -> pt.DataFrame[Nwp]:
    """Construct a continuous historical weather timeseries by stitching NWP runs.

    Args:
        start_date: Start date for the timeseries.
        end_date: End date for the timeseries.
        h3_indices: List of H3 indices to filter for.
        config: Data configuration.

    Returns:
        A Patito DataFrame containing the stitched NWP data.

    Raises:
        RuntimeError: If no NWP files are found in the date range.
    """
    config = config or DataConfig()
    # We expect NWP files to follow the `YYYY-MM-DDTHHZ.parquet` naming contract.
    # We use a robust parsing mechanism to ignore files that do not match this
    # format, preventing crashes from unexpected files in the directory.
    files = sorted(config.base_weather_path.glob("*.parquet"))
    relevant_files = []
    for f in files:
        try:
            file_date = datetime.strptime(f.stem[:10], "%Y-%m-%d").date()
            if start_date <= file_date <= end_date:
                relevant_files.append(f)
        except ValueError:
            # Skip files that don't match the expected date format
            continue

    if not relevant_files:
        raise RuntimeError(f"No NWP files found between {start_date} and {end_date}")

    weather_dfs = []
    for f in relevant_files:
        df = pl.read_parquet(f)
        df = df.filter(pl.col("h3_index").is_in(h3_indices))

        if not df.is_empty():
            weather_dfs.append(df)

    if not weather_dfs:
        raise RuntimeError(f"No weather data found for H3 indices in range {start_date}-{end_date}")

    # Combine all forecasts to retain a distribution of lead times
    combined = pl.concat(weather_dfs, how="diagonal")
    combined = combined.sort(["h3_index", "ensemble_member", "valid_time", "init_time"])

    # Descale data immediately to physical units (Float32)
    params = load_scaling_params()
    scaling_cols = params.select("col_name").to_series().to_list()

    # Only descale columns that are actually UInt8 to prevent double-descaling
    schema = combined.schema
    uint8_cols = [col for col, dtype in schema.items() if dtype == pl.UInt8 and col in scaling_cols]

    if uint8_cols:
        descale_exprs = uint8_to_physical_unit(params.filter(pl.col("col_name").is_in(uint8_cols)))
        combined = combined.with_columns(descale_exprs)

    return Nwp.validate(combined)
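The filename filter above can be exercised in isolation. Here is a standalone sketch of the same parsing logic (the stem_in_range helper and the sample stems are illustrative, not part of the package): stems that do not start with a YYYY-MM-DD date are silently skipped rather than crashing the load.

```python
from datetime import date, datetime

def stem_in_range(stem: str, start: date, end: date) -> bool:
    """Mirror the file filter: parse the first 10 characters of a stem as a
    date and skip anything that does not match YYYY-MM-DD."""
    try:
        file_date = datetime.strptime(stem[:10], "%Y-%m-%d").date()
    except ValueError:
        return False
    return start <= file_date <= end

stems = ["2024-01-05T00Z", "2024-01-09T12Z", "README", "2024-02-01T00Z"]
kept = [s for s in stems if stem_in_range(s, date(2024, 1, 1), date(2024, 1, 31))]
print(kept)  # ['2024-01-05T00Z', '2024-01-09T12Z']
```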

get_substation_metadata(config=None)

Load substation metadata and filter for those with available power data.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/data.py
def get_substation_metadata(config: DataConfig | None = None) -> pt.DataFrame[SubstationMetadata]:
    """Load substation metadata and filter for those with available power data."""
    config = config or DataConfig()
    metadata_path = _SETTINGS.nged_data_path / "parquet" / "substation_metadata.parquet"
    metadata_df = SubstationMetadata.validate(pl.read_parquet(metadata_path))

    # Only return substations we have local power data for in Delta Lake.
    # We use `scan_delta_table` to perform a lazy, optimized scan of the Delta Lake
    # table, which is more memory-efficient than `read_delta` for large tables.
    # This helper also ensures the timestamp column is UTC-aware.
    substations_with_telemetry = (
        cast(
            pl.DataFrame,
            scan_delta_table(config.base_power_path).select("substation_number").unique().collect(),
        )
        .to_series()
        .to_list()
    )

    return metadata_df.filter(pl.col("substation_number").is_in(substations_with_telemetry))

load_nwp_run(init_time, h3_indices, config=None)

Load a single NWP forecast run.

Parameters:

  init_time (datetime): The initialization time of the NWP run. Required.
  h3_indices (list[int]): List of H3 indices to filter for. Required.
  config (DataConfig | None): Data configuration. Default: None.

Returns:

  DataFrame[Nwp]: A Patito DataFrame containing the NWP data.

Raises:

  FileNotFoundError: If the NWP file for the given init_time does not exist.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/data.py
def load_nwp_run(
    init_time: datetime, h3_indices: list[int], config: DataConfig | None = None
) -> pt.DataFrame[Nwp]:
    """Load a single NWP forecast run.

    Args:
        init_time: The initialization time of the NWP run.
        h3_indices: List of H3 indices to filter for.
        config: Data configuration.

    Returns:
        A Patito DataFrame containing the NWP data.

    Raises:
        FileNotFoundError: If the NWP file for the given init_time does not exist.
    """
    config = config or DataConfig()
    # The filename format (`YYYY-MM-DDTHHZ.parquet`) is a strict contract with
    # the upstream data pipeline.
    filename = f"{init_time.strftime('%Y-%m-%dT%H')}Z.parquet"
    file_path = config.base_weather_path / filename

    try:
        df = pl.read_parquet(file_path)
    except FileNotFoundError:
        raise FileNotFoundError(
            f"NWP file not found at {file_path}. Expected format: YYYY-MM-DDTHHZ.parquet"
        )

    df = df.filter(pl.col("h3_index").is_in(h3_indices))

    if df.is_empty():
        raise RuntimeError(f"No data found for requested H3 indices in {file_path}")

    # Descale data immediately to physical units (Float32)
    params = load_scaling_params()
    scaling_cols = params.select("col_name").to_series().to_list()

    # Only descale columns that are actually UInt8 to prevent double-descaling
    schema = df.schema
    uint8_cols = [col for col, dtype in schema.items() if dtype == pl.UInt8 and col in scaling_cols]

    if uint8_cols:
        descale_exprs = uint8_to_physical_unit(params.filter(pl.col("col_name").is_in(uint8_cols)))
        df = df.with_columns(descale_exprs)

    return Nwp.validate(df)
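The naming contract works both ways: given an init_time, the expected filename is deterministic. A quick sketch using the same strftime pattern as the function above (the nwp_filename helper is illustrative):

```python
from datetime import datetime

def nwp_filename(init_time: datetime) -> str:
    # Same pattern as load_nwp_run: YYYY-MM-DDTHHZ.parquet
    return f"{init_time.strftime('%Y-%m-%dT%H')}Z.parquet"

print(nwp_filename(datetime(2024, 1, 5, 12)))  # 2024-01-05T12Z.parquet
print(nwp_filename(datetime(2023, 12, 31, 0)))  # 2023-12-31T00Z.parquet
```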

process_nwp_data(nwp, h3_indices)

Process NWP data: lead-time filtering and 30m interpolation for all members.

Note: Accumulated variables (e.g., precipitation, radiation) are already de-accumulated by Dynamical.org prior to download, and should not be differenced.

Parameters:

  nwp (LazyFrame): Raw NWP data. Required.
  h3_indices (list[int]): List of H3 indices to filter for. Required.

Returns:

  LazyFrame: Processed NWP data.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/data.py
def process_nwp_data(
    nwp: pl.LazyFrame,
    h3_indices: list[int],
) -> pl.LazyFrame:
    """Process NWP data: lead-time filtering and 30m interpolation for all members.

    Note: Accumulated variables (e.g., precipitation, radiation) are already
    de-accumulated by Dynamical.org prior to download, and should not be differenced.

    Args:
        nwp: Raw NWP data.
        h3_indices: List of H3 indices to filter for.

    Returns:
        Processed NWP data.
    """
    # 1. Filter by H3 indices to reduce data size
    lf = nwp.filter(pl.col("h3_index").is_in(h3_indices))

    # Descale data immediately to physical units (Float32) before any interpolation
    # or feature engineering. This ensures that interpolation happens in the
    # physical space and prevents mixing scaled (UInt8) and unscaled variables
    # in downstream calculations (like windchill or wind speed).
    params = load_scaling_params()
    scaling_cols = params.select("col_name").to_series().to_list()

    # Only descale columns that are actually UInt8 to prevent double-descaling
    # (e.g., if the function is called with already-descaled data in tests).
    schema = lf.collect_schema()
    uint8_cols = [col for col, dtype in schema.items() if dtype == pl.UInt8 and col in scaling_cols]

    if uint8_cols:
        descale_exprs = uint8_to_physical_unit(params.filter(pl.col("col_name").is_in(uint8_cols)))
        lf = lf.with_columns(descale_exprs)

    # 2. Calculate Lead Time and Filter (Fixing Leakage)
    # We cap the lead time at 336 hours (14 days) because ECMWF ENS
    # reliability drops significantly after day 14, and the model is only
    # validated for a 14-day horizon.
    lf = (
        lf.with_columns(
            lead_time_hours=(pl.col("valid_time") - pl.col("init_time")).dt.total_minutes() / 60.0
        )
        .with_columns(pl.col("lead_time_hours").cast(pl.Float32))
        .filter(pl.col("lead_time_hours") <= 336)
    )

    # 3. Interpolation (Fixing Nulls)
    # Ensure each group has at least two points for interpolation.
    # Groups with only 1 row cannot be interpolated and would violate the
    # 30-minute temporal resolution contract.
    lf = lf.filter(pl.len().over(["init_time", "h3_index", "ensemble_member"]) > 1)

    # Create a complete 30-minute time grid for each (init_time, h3_index, ensemble_member).
    # This replaces the eager `upsample` with a native lazy operation.
    grid = (
        lf.select(["init_time", "h3_index", "ensemble_member", "valid_time"])
        .group_by(["init_time", "h3_index", "ensemble_member"])
        .agg(
            [
                pl.col("valid_time").min().alias("start"),
                pl.col("valid_time").max().alias("end"),
            ]
        )
        .with_columns(valid_time=pl.datetime_ranges("start", "end", interval="30m"))
        .explode("valid_time")
        .drop(["start", "end"])
    )

    # Join the grid with the original data to create gaps for interpolation
    lf = grid.join(lf, on=["init_time", "h3_index", "ensemble_member", "valid_time"], how="left")

    # Identify columns for interpolation and forward-fill
    categorical_cols = ["categorical_precipitation_type_surface"]
    schema = lf.collect_schema()
    exclude_cols = ["valid_time", "h3_index", "ensemble_member", "init_time", "lead_time_hours"]
    numeric_cols = [
        col
        for col, dtype in schema.items()
        if dtype.is_numeric() and col not in exclude_cols and col not in categorical_cols
    ]

    # TEMPORAL INTERPOLATION & LEAKAGE:
    # Interpolating over `valid_time` within a single `init_time` is NOT data leakage.
    # All `valid_time` predictions in a single forecast run are generated simultaneously
    # at `init_time`. We are not looking into the future of when the forecast was made,
    # but merely interpolating the forecast's own future predictions to a higher
    # temporal resolution (30m).
    lf = lf.with_columns(
        [
            pl.col(c).interpolate().over(["init_time", "h3_index", "ensemble_member"])
            for c in numeric_cols
        ]
        + [
            # CATEGORICAL FORWARD-FILL:
            # Linear interpolation is physically meaningless for categorical variables.
            # We use forward-fill to maintain the discrete state of the weather condition.
            pl.col(c).forward_fill().over(["init_time", "h3_index", "ensemble_member"])
            for c in categorical_cols
            if c in schema.names()
        ]
    )

    # PHYSICAL WIND CALCULATION (Lazy):
    # After interpolating U and V components, we calculate physical wind speed
    # and direction. This ensures the circular topology of wind direction is
    # preserved without needing complex circular interpolation logic.
    wind_cols = ["wind_u_10m", "wind_v_10m", "wind_u_100m", "wind_v_100m"]
    if all(c in lf.collect_schema().names() for c in wind_cols):
        lf = lf.with_columns(
            [
                (pl.col("wind_u_10m") ** 2 + pl.col("wind_v_10m") ** 2)
                .sqrt()
                .alias("wind_speed_10m"),
                ((pl.arctan2("wind_u_10m", "wind_v_10m") * 180 / math.pi + 180) % 360).alias(
                    "wind_direction_10m"
                ),
                (pl.col("wind_u_100m") ** 2 + pl.col("wind_v_100m") ** 2)
                .sqrt()
                .alias("wind_speed_100m"),
                ((pl.arctan2("wind_u_100m", "wind_v_100m") * 180 / math.pi + 180) % 360).alias(
                    "wind_direction_100m"
                ),
            ]
        ).drop(wind_cols)

    # Recalculate lead_time_hours for the new 30m timestamps
    lf = lf.with_columns(
        lead_time_hours=(
            (pl.col("valid_time") - pl.col("init_time")).dt.total_minutes() / 60.0
        ).cast(pl.Float32)
    )

    # Final cast to Float32 for all physical variables to satisfy data contracts.
    # We exclude H3 index and ensemble member as they are identifiers, and
    # categorical columns which must remain as integers.
    schema = lf.collect_schema()
    physical_cols = [
        col
        for col, dtype in schema.items()
        if dtype.is_numeric()
        and col not in ["h3_index", "ensemble_member"]
        and col not in categorical_cols
    ]
    lf = lf.with_columns([pl.col(c).cast(pl.Float32) for c in physical_cols])

    return lf
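The lazy grid construction above replaces an eager upsample with datetime_ranges plus explode. The idea reduces to generating every 30-minute timestamp between a group's first and last valid_time; a plain-Python equivalent (the thirty_minute_grid helper is an illustrative stand-in, not the package's Polars code):

```python
from datetime import datetime, timedelta

def thirty_minute_grid(start: datetime, end: datetime) -> list[datetime]:
    """Every 30-minute timestamp from start to end inclusive, matching the
    interval="30m" datetime_ranges call in process_nwp_data."""
    out = []
    t = start
    while t <= end:
        out.append(t)
        t += timedelta(minutes=30)
    return out

# An hourly NWP group spanning two hours yields five 30-minute slots;
# the three original rows keep their values and the two new slots are
# filled by interpolation after the left join.
grid = thirty_minute_grid(datetime(2024, 1, 1, 0), datetime(2024, 1, 1, 2))
print(len(grid))  # 5
```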

xgboost_forecaster.features

Feature engineering for XGBoost forecasting.

Classes

Functions

add_autoregressive_lags(df, flows_30m, telemetry_delay_hours=24)

Add autoregressive lags to the feature matrix.

This function calculates the required lag dynamically to strictly prevent lookahead bias, ensuring that the model only uses power flow data that would have been available at the time the forecast was made.

Parameters:

  df (LazyFrame): The input LazyFrame (schema: SubstationFeatures). Required.
  flows_30m (LazyFrame): Historical power flows downsampled to 30m. Required.
  telemetry_delay_hours (int): Delay in hours for telemetry availability. Default: 24.

Returns:

  LazyFrame: LazyFrame with added lag features (schema: SubstationFeatures).

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/features.py
def add_autoregressive_lags(
    df: pl.LazyFrame, flows_30m: pl.LazyFrame, telemetry_delay_hours: int = 24
) -> pl.LazyFrame:
    """Add autoregressive lags to the feature matrix.

    This function calculates the required lag dynamically to strictly prevent
    lookahead bias, ensuring that the model only uses power flow data that
    would have been available at the time the forecast was made.

    Args:
        df: The input LazyFrame (schema: SubstationFeatures).
        flows_30m: Historical power flows downsampled to 30m.
        telemetry_delay_hours: Delay in hours for telemetry availability.

    Returns:
        LazyFrame with added lag features (schema: SubstationFeatures).
    """
    # 1. Calculate the required lag dynamically to strictly prevent lookahead bias
    df = (
        df.with_columns(
            lead_time_days=(pl.col("valid_time") - pl.col("init_time")).dt.total_seconds()
            / (24 * 3600)
        )
        .with_columns(
            lag_days=pl.max_horizontal(
                pl.lit(1),
                ((pl.col("lead_time_days") + telemetry_delay_hours / 24.0) / 7.0)
                .ceil()
                .cast(pl.Int32),
            )
            * 7
        )
        .with_columns(
            target_lag_time=pl.col("valid_time") - pl.duration(days=1) * pl.col("lag_days")
        )
    )

    # 2. Join flows_30m on ["substation_number", "target_lag_time"] to extract the exact
    # latest_available_weekly_power_lag without needing pre-calculated lag_7d or lag_14d columns.
    lag_df = flows_30m.select(
        pl.col("substation_number"),
        pl.col("timestamp").alias("target_lag_time"),
        pl.col("MW_or_MVA").alias("latest_available_weekly_power_lag"),
    )

    df = df.join(lag_df, on=["substation_number", "target_lag_time"], how="left")

    return df
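The lag arithmetic above reduces to a small scalar rule: pick the smallest whole number of weeks that clears the lead time plus the telemetry delay, with a floor of one week. A plain-Python sketch of the same expression (the lag_days function name is illustrative):

```python
import math

def lag_days(lead_time_days: float, telemetry_delay_hours: int = 24) -> int:
    """Weekly lag, in days, mirroring the Polars expression in
    add_autoregressive_lags."""
    weeks = math.ceil((lead_time_days + telemetry_delay_hours / 24.0) / 7.0)
    return max(1, weeks) * 7

print(lag_days(0.5))   # 7  -> short lead times use last week's power flow
print(lag_days(10.0))  # 14 -> day-10 forecasts must reach back two weeks
```

Using whole weeks keeps the lag aligned with the weekly demand cycle, so the lagged value always comes from the same day of week and time of day as the target.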

add_time_features(df)

Add lead_time_hours and nwp_init_hour features.

Parameters:

  df (LazyFrame): The input LazyFrame (schema: SubstationFeatures). Required.

Returns:

  LazyFrame: LazyFrame with added time features (schema: SubstationFeatures).

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/features.py
def add_time_features(df: pl.LazyFrame) -> pl.LazyFrame:
    """Add lead_time_hours and nwp_init_hour features.

    Args:
        df: The input LazyFrame (schema: SubstationFeatures).

    Returns:
        LazyFrame with added time features (schema: SubstationFeatures).
    """

    return df.with_columns(
        lead_time_hours=(
            pl.col(NwpColumns.VALID_TIME) - pl.col(NwpColumns.INIT_TIME)
        ).dt.total_minutes()
        / 60.0,
        nwp_init_hour=pl.col(NwpColumns.INIT_TIME).dt.hour().cast(pl.Int32),
    )
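Both features are simple to verify by hand. A standalone equivalent of the same arithmetic at scalar values:

```python
from datetime import datetime

init_time = datetime(2024, 1, 1, 0)
valid_time = datetime(2024, 1, 3, 12, 30)

# Same arithmetic as the Polars expression: elapsed time in hours.
lead_time_hours = (valid_time - init_time).total_seconds() / 3600.0
nwp_init_hour = init_time.hour

print(lead_time_hours, nwp_init_hour)  # 60.5 0
```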

add_weather_features(weather, history=None)

Add lags and trends to weather data.

Parameters:

  weather (LazyFrame): Current weather forecast (schema: ProcessedNwp). Required.
  history (LazyFrame | None): Historical weather data (optional, used for lags). Default: None.

Returns:

  LazyFrame: LazyFrame with added weather features (schema: ProcessedNwp).

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/features.py
def add_weather_features(
    weather: pl.LazyFrame, history: pl.LazyFrame | None = None
) -> pl.LazyFrame:
    """Add lags and trends to weather data.

    Args:
        weather: Current weather forecast (schema: ProcessedNwp).
        history: Historical weather data (optional, used for lags).

    Returns:
        LazyFrame with added weather features (schema: ProcessedNwp).
    """
    schema_names = weather.collect_schema().names()
    if NwpColumns.TEMPERATURE_2M not in schema_names:
        raise ValueError(
            f"Required weather column '{NwpColumns.TEMPERATURE_2M}' is missing from the input "
            f"LazyFrame. Available columns: {schema_names}"
        )

    # Add windchill if both temperature and wind speed are present
    if NwpColumns.TEMPERATURE_2M in schema_names and NwpColumns.WIND_SPEED_10M in schema_names:
        v_kmh = pl.col(NwpColumns.WIND_SPEED_10M) * 3.6
        temp = pl.col(NwpColumns.TEMPERATURE_2M)

        weather = weather.with_columns(
            windchill=(
                13.12 + 0.6215 * temp - 11.37 * (v_kmh**0.16) + 0.3965 * temp * (v_kmh**0.16)
            ).cast(pl.Float32)
        )

    weather = weather.sort(NwpColumns.INIT_TIME)
    if history is not None:
        full_weather = pl.concat([history, weather], how="diagonal").sort(NwpColumns.INIT_TIME)
    else:
        full_weather = weather

    # Check if we have enough history for the 14-day lag
    sorted_weather = full_weather.sort(NwpColumns.VALID_TIME)
    min_time_df = sorted_weather.head(1).select(NwpColumns.VALID_TIME).collect()
    max_time_df = sorted_weather.tail(1).select(NwpColumns.VALID_TIME).collect()

    if (
        isinstance(min_time_df, pl.DataFrame)
        and isinstance(max_time_df, pl.DataFrame)
        and min_time_df.height > 0
        and max_time_df.height > 0
    ):
        min_time = min_time_df.item(0, NwpColumns.VALID_TIME)
        max_time = max_time_df.item(0, NwpColumns.VALID_TIME)

        if min_time is not None and max_time is not None:
            if max_time - min_time < timedelta(days=14):
                log.warning(
                    "Provided weather data does not cover the required 14-day lag range. "
                    "Weather lag features will be null for the first 14 days of the forecast."
                )

    def _add_lag_asof(
        df: pl.LazyFrame, source_df: pl.LazyFrame, offset: timedelta, suffix: str
    ) -> pl.LazyFrame:
        """Add a lagged weather feature using an exact point-in-time join.

        This function uses a backward `join_asof` on `init_time` with `target_valid_time`
        to fetch historical weather forecasts without lookahead bias. It simulates the exact
        knowledge state at `init_time` by finding the most recent forecast that was valid
        at `target_valid_time` (which is `valid_time - offset`).

        This prevents data leakage by ensuring the model only sees weather forecasts that
        were actually available at the time the forecast was made.
        """
        by_cols = ["target_valid_time"]
        if NwpColumns.H3_INDEX in schema_names:
            by_cols.append(NwpColumns.H3_INDEX)

        source_schema = source_df.collect_schema().names()
        actual_temp_col = NwpColumns.TEMPERATURE_2M
        actual_sw_col = NwpColumns.SW_RADIATION

        source_cols = [NwpColumns.VALID_TIME, NwpColumns.INIT_TIME, actual_temp_col]
        if NwpColumns.H3_INDEX in schema_names:
            source_cols.append(NwpColumns.H3_INDEX)
        if actual_sw_col in source_schema:
            source_cols.append(actual_sw_col)

        # Filter source_df to only include ensemble_member == 0 if it exists.
        # This ensures that lag features are consistent across all ensemble members
        # and reduces the memory footprint of the join.
        if NwpColumns.ENSEMBLE_MEMBER in source_schema:
            source_df = source_df.filter(pl.col(NwpColumns.ENSEMBLE_MEMBER) == 0)

        right = source_df.select(source_cols).rename(
            {
                NwpColumns.VALID_TIME: "target_valid_time",
                actual_temp_col: f"{NwpColumns.TEMPERATURE_2M}_{suffix}",
            }
        )
        if actual_sw_col in source_schema:
            right = right.rename({actual_sw_col: f"{NwpColumns.SW_RADIATION}_{suffix}"})

        left = df.with_columns(target_valid_time=pl.col(NwpColumns.VALID_TIME) - offset)

        # We explicitly sort by the group keys (by_cols) and the join key (init_time)
        # to ensure the data is correctly ordered for the asof join.
        # We pass check_sortedness=False to suppress a false-positive warning from Polars
        # that can occur even when the data is correctly sorted.
        return (
            left.sort(by_cols + [NwpColumns.INIT_TIME])
            .join_asof(
                right.sort(by_cols + [NwpColumns.INIT_TIME]),
                on=NwpColumns.INIT_TIME,
                by=by_cols,
                strategy="backward",
                check_sortedness=False,
            )
            .drop("target_valid_time")
        )

    weather = _add_lag_asof(weather, full_weather, timedelta(days=7), "lag_7d")
    weather = _add_lag_asof(weather, full_weather, timedelta(days=14), "lag_14d")
    weather = _add_lag_asof(weather, full_weather, timedelta(hours=6), "6h_ago")

    return weather.with_columns(
        temp_trend_6h=(
            pl.col(NwpColumns.TEMPERATURE_2M).cast(pl.Float32)
            - pl.col(f"{NwpColumns.TEMPERATURE_2M}_6h_ago").cast(pl.Float32)
        ).cast(pl.Float32)
    )
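The windchill expression embedded above is the standard North American wind chill index, with temperature in degrees Celsius and wind speed converted from m/s to km/h. Evaluated at scalar values as a sanity check (the windchill function name is illustrative; the coefficients are copied from the code above):

```python
def windchill(temp_c: float, wind_speed_ms: float) -> float:
    """Wind chill index as computed in add_weather_features."""
    v_kmh = wind_speed_ms * 3.6
    return (
        13.12
        + 0.6215 * temp_c
        - 11.37 * v_kmh**0.16
        + 0.3965 * temp_c * v_kmh**0.16
    )

# 0 degC air with a 10 m/s wind feels around -7 degC.
print(round(windchill(0.0, 10.0), 1))
```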

xgboost_forecaster.model

XGBoost implementation of the Forecaster interface.

Classes

XGBoostForecaster

Bases: BaseForecaster

XGBoost implementation of the Forecaster interface.

This class handles the full lifecycle of an XGBoost model: training and production inference.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/model.py
class XGBoostForecaster(BaseForecaster):
    """XGBoost implementation of the Forecaster interface.

    This class handles the full lifecycle of an XGBoost model: training
    and production inference.
    """

    model: XGBRegressor | None
    target_map: pl.DataFrame | pl.LazyFrame | None
    config: ModelConfig
    feature_names: list[str]

    def __init__(self, model: XGBRegressor | None = None):
        """Initialize the forecaster.

        Args:
            model: An optional pre-trained XGBoost model.
        """
        self.model = model
        self.target_map = None
        self.feature_names = []

    def _get_target_map_df(self) -> pl.DataFrame:
        """Get the target map as a Polars DataFrame.

        Returns:
            The target map collected into a DataFrame.

        Raises:
            ValueError: If target_map is not set.
        """
        if self.target_map is None:
            raise ValueError("target_map must be set before calling this method.")

        if isinstance(self.target_map, pl.LazyFrame):
            return cast(pl.DataFrame, self.target_map.collect())
        return cast(pl.DataFrame, self.target_map)

    def log_model(self, model_name: str) -> None:
        """Log the model to MLflow."""
        if self.model is not None:
            mlflow.xgboost.log_model(self.model, artifact_path="model")

        if self.target_map is not None:
            # Ensure target_map is a DataFrame before writing to JSON
            target_map_df = self._get_target_map_df()
            with tempfile.TemporaryDirectory() as tmpdir:
                path = os.path.join(tmpdir, "target_map.json")
                target_map_df.write_json(path)
                mlflow.log_artifact(path, artifact_path="metadata")

    def _prepare_features(self, df: pl.DataFrame | pl.LazyFrame) -> pl.DataFrame | pl.LazyFrame:
        """Extract the feature matrix.

        Args:
            df: The joined input data with all features already added.

        Returns:
            A Polars DataFrame or LazyFrame containing only the feature columns.
        """
        # Extract feature matrix using explicit feature names.
        # We prioritize self.feature_names (set during training) over the config
        # to ensure consistency between training and inference.
        if self.feature_names:
            feature_names = [f for f in self.feature_names if f != "MW_or_MVA"]
        elif hasattr(self, "config") and self.config.features.feature_names:
            feature_names = [f for f in self.config.features.feature_names if f != "MW_or_MVA"]
        else:
            raise ValueError(
                "Feature names must be explicitly provided either in the config or "
                "during training. Fallback feature selection is disabled to prevent "
                "silent model degradation."
            )

        res = df.select(feature_names)

        if isinstance(res, pl.LazyFrame):
            res_schema = res.collect_schema()
        else:
            res_schema = res.schema  # Fail loudly if columns are missing

        # Ensure substation_number is treated as a categorical feature by XGBoost
        if "substation_number" in res_schema.names():
            res = res.with_columns(pl.col("substation_number").cast(pl.String).cast(pl.Categorical))

        # Ensure categorical precipitation is treated as a categorical feature
        if "categorical_precipitation_type_surface" in res_schema.names():
            res = res.with_columns(
                pl.col("categorical_precipitation_type_surface")
                .cast(pl.String)
                .cast(pl.Categorical)
            )

        return res

    def _collapse_lead_times(
        self,
        df: pl.LazyFrame,
        nwp_cutoff: datetime,
        delay_hours: int,
    ) -> pl.LazyFrame:
        """Collapse lead times to keep only the latest available forecast for each valid time.

        This simulates real-time inference by ensuring that for any given valid_time,
        we only use the most recent NWP forecast that would have been available
        at the nwp_cutoff time.

        Args:
            df: The input LazyFrame containing NWP data.
            nwp_cutoff: The time at which we are making the forecast.
            delay_hours: The availability delay of the NWP data.

        Returns:
            A LazyFrame with collapsed lead times.
        """
        return (
            df.filter(pl.col(NwpColumns.INIT_TIME) + pl.duration(hours=delay_hours) <= nwp_cutoff)
            .sort(NwpColumns.INIT_TIME)
            .group_by(
                [
                    NwpColumns.VALID_TIME,
                    NwpColumns.H3_INDEX,
                    NwpColumns.ENSEMBLE_MEMBER,
                ]
            )
            .last()
        )

    def _prepare_and_join_nwps(
        self,
        nwps: Mapping[NwpModel, pl.LazyFrame],
        nwp_cutoff: datetime | None = None,
        collapse_lead_times: bool = False,
    ) -> pl.LazyFrame:
        """Prepare and join multiple NWP sources."""
        # Get delay from config, default to 3 hours if not set
        delay_hours = self.config.nwp_availability_delay_hours if hasattr(self, "config") else 3

        nwp_list = []
        for i, (name, lf) in enumerate(nwps.items()):
            # Add weather features (lags, trends, etc.)
            lf = add_weather_features(lf)
            # Add time features (lead_time_hours, nwp_init_hour)
            lf = add_time_features(lf)

            # CRITICAL: This filter is the sole mechanism preventing future data leakage.
            # It ensures that we only use NWP forecasts that would have been available
            # at the valid time, accounting for the data ingestion delay.
            # init_time + delay_hours <= valid_time
            available_nwp = lf.filter(
                pl.col(NwpColumns.INIT_TIME) + pl.duration(hours=delay_hours)
                <= pl.col(NwpColumns.VALID_TIME)
            )

            if nwp_cutoff is None:
                # During training, we use all available lead times.
                # The target_horizon_hours filter is removed to maximize training data.

                # For the primary NWP model (i == 0), we only use the control member (0)
                # to avoid inflating the training set with highly correlated ensemble members.
                if i == 0:
                    available_nwp = available_nwp.filter(pl.col(NwpColumns.ENSEMBLE_MEMBER) == 0)

                latest_nwp = available_nwp
            else:
                if collapse_lead_times:
                    latest_nwp = self._collapse_lead_times(available_nwp, nwp_cutoff, delay_hours)
                else:
                    latest_nwp = available_nwp.filter(
                        pl.col(NwpColumns.INIT_TIME) + pl.duration(hours=delay_hours) <= nwp_cutoff
                    )

            if i == 0:
                prefixed_nwp = latest_nwp.with_columns(
                    available_time=pl.col(NwpColumns.INIT_TIME) + pl.duration(hours=delay_hours)
                )
            else:
                prefix = f"{name.value}_"
                # For secondary NWP models, filter to ensemble_member == 0 (control member)
                # and drop the column to avoid arbitrary pairing with primary NWP members.
                latest_nwp = latest_nwp.filter(pl.col(NwpColumns.ENSEMBLE_MEMBER) == 0).drop(
                    NwpColumns.ENSEMBLE_MEMBER
                )

                nwp_schema_names = latest_nwp.collect_schema().names()
                rename_mapping = {
                    col: f"{prefix}{col}"
                    for col in nwp_schema_names
                    if col
                    not in [
                        NwpColumns.VALID_TIME,
                        NwpColumns.H3_INDEX,
                    ]
                }
                prefixed_nwp = latest_nwp.rename(rename_mapping).with_columns(
                    available_time=pl.col(f"{prefix}{NwpColumns.INIT_TIME}")
                    + pl.duration(hours=delay_hours)
                )
            nwp_list.append(prefixed_nwp)

        combined_nwps = nwp_list[0]
        for other_nwp in nwp_list[1:]:
            # We explicitly sort by the group keys (valid_time, h3_index) and the join key (available_time)
            # to ensure the data is correctly ordered for the asof join.
            # We pass check_sortedness=False to suppress a false-positive warning from Polars
            # that can occur even when the data is correctly sorted.
            combined_nwps = (
                combined_nwps.sort([NwpColumns.VALID_TIME, NwpColumns.H3_INDEX, "available_time"])
                .join_asof(
                    other_nwp.sort([NwpColumns.VALID_TIME, NwpColumns.H3_INDEX, "available_time"]),
                    on="available_time",
                    by=[NwpColumns.VALID_TIME, NwpColumns.H3_INDEX],
                    check_sortedness=False,
                )
                .with_columns(
                    # Explicitly cast h3_index to UInt64 after the join to prevent silent type
                    # coercion (e.g., to Float64 or Int64) during Polars joins, which is
                    # critical for downstream H3 operations and memory efficiency.
                    pl.col(NwpColumns.H3_INDEX).cast(pl.UInt64)
                )
            )
        return combined_nwps

    def _prepare_training_data(
        self,
        flows_30m: pl.LazyFrame,
        metadata_lf: pl.LazyFrame,
        combined_nwps_lf: pl.LazyFrame | None = None,
    ) -> pl.LazyFrame:
        """Prepare data for training.

        For training, we start with the historical power flows and join the
        substation metadata and NWP forecasts.

        Args:
            flows_30m: Historical power flows downsampled to 30m.
            metadata_lf: Substation metadata.
            combined_nwps_lf: Combined NWP forecasts.

        Returns:
            A LazyFrame containing the joined training data.
        """
        df_lf = (
            flows_30m.rename({"timestamp": NwpColumns.VALID_TIME})
            .join(
                metadata_lf.rename({"h3_res_5": NwpColumns.H3_INDEX}),
                on="substation_number",
            )
            .with_columns(
                # Explicitly cast h3_index to UInt64 after the join to prevent silent type
                # coercion (e.g., to Float64 or Int64) during Polars joins, which is
                # critical for downstream H3 operations and memory efficiency.
                pl.col(NwpColumns.H3_INDEX).cast(pl.UInt64)
            )
        )

        if combined_nwps_lf is not None:
            df_lf = df_lf.join(
                combined_nwps_lf,
                on=[NwpColumns.VALID_TIME, NwpColumns.H3_INDEX],
                how="left",
            ).with_columns(
                # Explicitly cast h3_index to UInt64 after the join to prevent silent type
                # coercion (e.g., to Float64 or Int64) during Polars joins, which is
                # critical for downstream H3 operations and memory efficiency.
                pl.col(NwpColumns.H3_INDEX).cast(pl.UInt64)
            )
        return df_lf

    def _prepare_inference_data(
        self,
        metadata_lf: pl.LazyFrame,
        combined_nwps_lf: pl.LazyFrame,
    ) -> pl.LazyFrame:
        """Prepare data for inference.

        For inference, we start with the NWP forecasts and join the substation
        metadata.

        Args:
            metadata_lf: Substation metadata.
            combined_nwps_lf: Combined NWP forecasts.

        Returns:
            A LazyFrame containing the joined inference data.
        """
        return combined_nwps_lf.join(
            metadata_lf.rename({"h3_res_5": NwpColumns.H3_INDEX}),
            on=NwpColumns.H3_INDEX,
            how="inner",
        ).with_columns(
            # Explicitly cast h3_index to UInt64 after the join to prevent silent type
            # coercion (e.g., to Float64 or Int64) during Polars joins, which is
            # critical for downstream H3 operations and memory efficiency.
            pl.col(NwpColumns.H3_INDEX).cast(pl.UInt64)
        )

    def _prepare_data_for_model(
        self,
        flows_30m: pl.LazyFrame,
        substation_metadata: pt.DataFrame[SubstationMetadata],
        target_map_df: pl.DataFrame | pl.LazyFrame,
        nwps: Mapping[NwpModel, pl.LazyFrame] | None = None,
        inference_params: InferenceParams | None = None,
        collapse_lead_times: bool = False,
    ) -> pl.LazyFrame:
        """Prepare data for training or prediction.

        Args:
            flows_30m: Historical power flows downsampled to 30m.
            substation_metadata: Substation metadata.
            target_map_df: Target mapping dataframe.
            nwps: Dictionary of NWP data.
            inference_params: Inference parameters (only for prediction).
            collapse_lead_times: Whether to collapse lead times (only for prediction).

        Returns:
            Prepared LazyFrame ready for feature extraction.
        """
        metadata_lf = substation_metadata.select(["substation_number", "h3_res_5"]).lazy()
        target_map_lf = target_map_df.lazy()

        # 1. Prepare NWPs and join
        is_training = inference_params is None

        combined_nwps_lf = None
        if nwps:
            nwp_cutoff = inference_params.forecast_time if inference_params is not None else None
            combined_nwps_lf = self._prepare_and_join_nwps(
                nwps,
                nwp_cutoff=nwp_cutoff,
                collapse_lead_times=collapse_lead_times,
            )

        # 2. Join power flows and metadata (training) or metadata only (inference)
        if is_training:
            df_lf = self._prepare_training_data(flows_30m, metadata_lf, combined_nwps_lf)
        else:
            if combined_nwps_lf is None:
                raise ValueError("XGBoostForecaster requires NWP data for prediction.")
            df_lf = self._prepare_inference_data(metadata_lf, combined_nwps_lf)

        # 3. Add lags and features
        # Handle missing init_time (e.g. for autoregressive-only models)
        if NwpColumns.INIT_TIME not in df_lf.collect_schema().names():
            df_lf = df_lf.with_columns(**{NwpColumns.INIT_TIME: pl.col(NwpColumns.VALID_TIME)})

        # Get telemetry delay from config, default to 24 hours if not set
        telemetry_delay_hours = self.config.telemetry_delay_hours if hasattr(self, "config") else 24
        df_lf = add_autoregressive_lags(
            df_lf, flows_30m, telemetry_delay_hours=telemetry_delay_hours
        )

        # Normalize by peak capacity
        df_lf = df_lf.join(
            target_map_lf.select(["substation_number", "peak_capacity_MW_or_MVA"]),
            on="substation_number",
            how="inner",  # Drop substations missing from the target_map
        ).with_columns(
            latest_available_weekly_power_lag=pl.col("latest_available_weekly_power_lag")
            / pl.col("peak_capacity_MW_or_MVA")
        )

        if is_training:
            # For training, also normalize target
            df_lf = df_lf.with_columns(
                MW_or_MVA=pl.col("MW_or_MVA") / pl.col("peak_capacity_MW_or_MVA")
            )
        else:
            # For prediction, add dummy target for validation
            df_lf = df_lf.with_columns(MW_or_MVA=pl.lit(0.0, dtype=pl.Float32))

        df_lf = add_cyclical_temporal_features(df_lf, time_col=NwpColumns.VALID_TIME)

        # 4. Type casting
        df_lf = df_lf.with_columns(
            [
                pl.col("substation_number").cast(pl.Int32),
                pl.col("MW_or_MVA").cast(pl.Float32),
            ]
        )
        if NwpColumns.ENSEMBLE_MEMBER in df_lf.collect_schema().names():
            df_lf = df_lf.with_columns(
                pl.col(NwpColumns.ENSEMBLE_MEMBER).fill_null(0).cast(pl.UInt8)
            )

        # DATA TYPE RATIONALE:
        # Weather features are kept as Float32 in memory rather than UInt8 to:
        # 1. Preserve precision from 30-minute interpolation (avoiding "staircase" effects).
        # 2. Prevent silent underflow during feature engineering (e.g., calculating trends
        #    via subtraction).
        # 3. Align with XGBoost's native internal data type (Float32).

        # Cast all floats to Float32 for Patito
        df_lf = df_lf.with_columns(pl.col(pl.Float64).cast(pl.Float32))

        return df_lf

    def train(
        self,
        config: ModelConfig,
        flows_30m: pl.LazyFrame,
        substation_metadata: pt.DataFrame[SubstationMetadata],
        nwps: Mapping[NwpModel, pl.LazyFrame] | None = None,
    ) -> "XGBoostForecaster":
        """Train the XGBoost model.

        Args:
            config: The model configuration object.
            flows_30m: Historical power flow data downsampled to 30m.
            substation_metadata: The substation metadata containing h3 mapping.
            nwps: A dictionary of weather forecast dataframes.

        Returns:
            The trained XGBoostForecaster instance.
        """
        self.config = config
        log.info("Starting XGBoost training...")

        # Log input data info
        # Note: We don't collect the full LazyFrames here to avoid OOM,
        # just logging their presence and schema.
        log.info(f"Input flows_30m columns: {flows_30m.collect_schema().names()}")
        if nwps:
            for name, lf in nwps.items():
                log.info(f"Input NWP {name.value} columns: {lf.collect_schema().names()}")

        if len(config.features.nwps) > 0 and not nwps:
            raise ValueError("Model config requires NWPs, but none were provided.")

        if self.target_map is None:
            raise ValueError("target_map must be set before calling train.")

        joined_lf = self._prepare_data_for_model(
            flows_30m=flows_30m,
            substation_metadata=substation_metadata,
            target_map_df=self.target_map,
            nwps=nwps,
        )

        # Prepare features and target
        feature_lf = self._prepare_features(joined_lf)
        feature_cols = feature_lf.collect_schema().names()

        # Collect only necessary columns and drop nulls
        critical_cols = ["MW_or_MVA"]
        if nwps:
            critical_cols.extend(
                [
                    NwpColumns.TEMPERATURE_2M,
                    NwpColumns.SW_RADIATION,
                    NwpColumns.WIND_SPEED_10M,
                ]
            )

        # Apply random sampling if max_training_samples is set to prevent OOM errors
        # during collection.
        if config.max_training_samples is not None:
            log.info(f"Sampling training data to {config.max_training_samples} samples.")
            # LazyFrame doesn't have a direct sample(n=...) method, so we collect and then sample.
            # This still helps prevent OOM in XGBoost training itself, even if the collection
            # is the bottleneck.
            raw_df = cast(
                pl.DataFrame,
                joined_lf.select(list(set(feature_cols + ["MW_or_MVA"]))).collect(),
            ).sample(n=config.max_training_samples, seed=42)
        else:
            raw_df = cast(
                pl.DataFrame,
                joined_lf.select(list(set(feature_cols + ["MW_or_MVA"]))).collect(),
            )
        log.info(f"Collected raw_df shape before dropping nulls: {raw_df.shape}")
        joined_df = raw_df.drop_nulls(subset=critical_cols)

        dropped_rows = len(raw_df) - len(joined_df)
        if dropped_rows > 0:
            log.warning(
                f"Dropped {dropped_rows} rows during training due to nulls in critical columns: {critical_cols}"
            )

        if joined_df.is_empty():
            raise ValueError("No training data remaining after dropping nulls in critical columns.")

        SubstationFeatures.validate(
            joined_df, allow_missing_columns=True, allow_superfluous_columns=True
        )

        X = cast(pl.DataFrame, self._prepare_features(joined_df))
        y = joined_df.select("MW_or_MVA").to_series()

        # NaN/Inf checks
        if (
            X.select(
                pl.any_horizontal(
                    pl.col(pl.Float32, pl.Float64).is_nan()
                    | pl.col(pl.Float32, pl.Float64).is_infinite()
                )
            )
            .sum()
            .item()
            > 0
        ):
            raise ValueError("Input features X contain NaN or Inf values")

        if y.is_nan().any() or y.is_infinite().any():
            raise ValueError("Target y contains NaN or Inf values")

        # Save feature names
        self.feature_names = X.columns

        hyperparams = XGBoostHyperparameters(**config.hyperparameters)
        model = XGBRegressor(**hyperparams.model_dump())
        model.fit(X.to_arrow(), y.to_arrow())
        self.model = model

        return self

    def predict(
        self,
        substation_metadata: pt.DataFrame[SubstationMetadata],
        inference_params: InferenceParams,
        flows_30m: pl.LazyFrame,
        nwps: Mapping[NwpModel, pl.LazyFrame] | None = None,
        collapse_lead_times: bool = False,
    ) -> pt.DataFrame[PowerForecast]:
        """Execute the inference logic.

        Args:
            substation_metadata: The substation metadata containing h3 mapping.
            inference_params: Parameters for inference.
            flows_30m: Historical power flow data downsampled to 30m (for lags).
            nwps: A dictionary of weather forecast lazyframes.
            collapse_lead_times: Whether to collapse lead times to simulate real-time inference by keeping only the latest available NWP forecast for each valid time.

        Returns:
            A Patito DataFrame containing the predictions.
        """
        if self.model is None:
            raise ValueError("Model must be trained before calling predict.")

        if not nwps:
            raise ValueError("XGBoostForecaster requires NWP data for prediction.")

        if self.target_map is None:
            raise ValueError("target_map must be set before calling predict.")

        df_lf = self._prepare_data_for_model(
            flows_30m=flows_30m,
            substation_metadata=substation_metadata,
            target_map_df=self.target_map,
            nwps=nwps,
            inference_params=inference_params,
            collapse_lead_times=collapse_lead_times,
        )

        feature_lf = self._prepare_features(df_lf)
        feature_cols = feature_lf.collect_schema().names()

        # Output columns needed for the final result
        output_cols = [
            NwpColumns.VALID_TIME,
            "substation_number",
            NwpColumns.ENSEMBLE_MEMBER,
            NwpColumns.INIT_TIME,
            "nwp_init_hour",
            "lead_time_hours",
        ]

        # Nulls are intentionally kept during prediction; XGBoost handles missing values natively.
        df = cast(
            pl.DataFrame,
            df_lf.select(
                list(set(feature_cols + output_cols + ["MW_or_MVA", "peak_capacity_MW_or_MVA"]))
            ).collect(),
        )

        if df.is_empty():
            raise ValueError("No inference data remaining.")

        SubstationFeatures.validate(df, allow_missing_columns=True, allow_superfluous_columns=True)

        X = cast(pl.DataFrame, self._prepare_features(df))

        if (
            X.select(
                pl.any_horizontal(
                    pl.col(pl.Float32, pl.Float64).is_nan()
                    | pl.col(pl.Float32, pl.Float64).is_infinite()
                )
            )
            .sum()
            .item()
            > 0
        ):
            raise ValueError("Input features X contain NaN or Inf values")

        # Enforce exact column order from training
        if hasattr(self.model, "feature_names_in_"):
            X = X.select(self.model.feature_names_in_)
        elif hasattr(self, "feature_names") and self.feature_names:
            X = X.select(self.feature_names)

        preds = self.model.predict(X.to_arrow())

        # Descale predictions and return with correct schema
        fcst_init_time = inference_params.forecast_time
        model_name = inference_params.power_fcst_model_name or "xgboost_global"

        res = (
            df.with_columns(
                MW_or_MVA=pl.Series(values=preds, dtype=pl.Float32),
            )
            .with_columns(
                MW_or_MVA=pl.col("MW_or_MVA") * pl.col("peak_capacity_MW_or_MVA"),
                power_fcst_model_name=pl.lit(model_name).cast(pl.Categorical),
                power_fcst_init_time=pl.lit(fcst_init_time).cast(pl.Datetime("us", "UTC")),
                nwp_init_time=pl.col(NwpColumns.INIT_TIME).cast(pl.Datetime("us", "UTC")),
                power_fcst_init_year_month=pl.lit(fcst_init_time.strftime("%Y-%m")).cast(pl.String),
            )
            .select(
                [
                    NwpColumns.VALID_TIME,
                    "substation_number",
                    NwpColumns.ENSEMBLE_MEMBER,
                    "power_fcst_model_name",
                    "power_fcst_init_time",
                    "nwp_init_time",
                    "power_fcst_init_year_month",
                    "nwp_init_hour",
                    "lead_time_hours",
                    "MW_or_MVA",
                ]
            )
        )

        # Ensure all substations in the inference set are present in the target_map
        # to prevent silent dropping of forecasts.
        if res.select("substation_number").n_unique() < df.select("substation_number").n_unique():
            missing_substations = set(df.select("substation_number").to_series()) - set(
                res.select("substation_number").to_series()
            )
            raise ValueError(
                f"The following substations are missing from the target_map: {missing_substations}. "
                "All substations in the inference set must have a corresponding entry in the "
                "target_map to prevent null forecasts."
            )

        # Handle potential nulls in ensemble_member (required by schema)
        res = res.with_columns(pl.col(NwpColumns.ENSEMBLE_MEMBER).fill_null(0).cast(pl.UInt8))

        return PowerForecast.validate(res)
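The availability-delay filter that `_prepare_and_join_nwps` marks as critical reduces to a single timestamp comparison. A minimal stdlib sketch (the 3-hour delay matches the package default; the helper name is illustrative, not part of the API):

```python
from datetime import datetime, timedelta

def is_available(init_time: datetime, valid_time: datetime, delay_hours: int = 3) -> bool:
    # A forecast row may only be used if its NWP run, plus the ingestion
    # delay, was already on disk at the time being predicted.
    return init_time + timedelta(hours=delay_hours) <= valid_time

run = datetime(2024, 1, 1, 0, 0)
print(is_available(run, datetime(2024, 1, 1, 3, 0)))  # True: exactly at the boundary
print(is_available(run, datetime(2024, 1, 1, 2, 0)))  # False: would leak future data
```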
Functions
__init__(model=None)

Initialize the forecaster.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `XGBRegressor \| None` | An optional pre-trained XGBoost model. | `None` |
Source code in packages/xgboost_forecaster/src/xgboost_forecaster/model.py
def __init__(self, model: XGBRegressor | None = None):
    """Initialize the forecaster.

    Args:
        model: An optional pre-trained XGBoost model.
    """
    self.model = model
    self.target_map = None
    self.feature_names = []
log_model(model_name)

Log the model to MLflow.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/model.py
def log_model(self, model_name: str) -> None:
    """Log the model to MLflow."""
    if self.model is not None:
        mlflow.xgboost.log_model(self.model, artifact_path="model")

    if self.target_map is not None:
        # Ensure target_map is a DataFrame before writing to JSON
        target_map_df = self._get_target_map_df()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "target_map.json")
            target_map_df.write_json(path)
            mlflow.log_artifact(path, artifact_path="metadata")
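The `target_map` artifact round-trip in `log_model` follows a standard temp-file pattern. A sketch with hypothetical contents (a substation-to-peak-capacity mapping) and no MLflow or Polars dependency:

```python
import json
import os
import tempfile

# Hypothetical stand-in for the target_map contents.
target_map = {"123456": 40.0, "654321": 25.5}

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "target_map.json")
    with open(path, "w") as f:
        json.dump(target_map, f)   # write_json(path) in the Polars version
    with open(path) as f:
        restored = json.load(f)    # the artifact reads back unchanged

print(restored == target_map)  # True
```

Writing into a `TemporaryDirectory` guarantees the serialized file is cleaned up after `mlflow.log_artifact` has copied it into the run's artifact store.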
predict(substation_metadata, inference_params, flows_30m, nwps=None, collapse_lead_times=False)

Execute the inference logic.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `substation_metadata` | `DataFrame[SubstationMetadata]` | The substation metadata containing h3 mapping. | *required* |
| `inference_params` | `InferenceParams` | Parameters for inference. | *required* |
| `flows_30m` | `LazyFrame` | Historical power flow data downsampled to 30m (for lags). | *required* |
| `nwps` | `Mapping[NwpModel, LazyFrame] \| None` | A dictionary of weather forecast lazyframes. | `None` |
| `collapse_lead_times` | `bool` | Whether to collapse lead times to simulate real-time inference by keeping only the latest available NWP forecast for each valid time. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `DataFrame[PowerForecast]` | A Patito DataFrame containing the predictions. |
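The behaviour behind `collapse_lead_times=True` can be sketched without Polars: per valid time, keep the newest run whose init time plus the availability delay is not after the forecast cutoff. The timestamps and values below are illustrative:

```python
from datetime import datetime, timedelta

def collapse_lead_times(rows, cutoff, delay_hours=3):
    """rows: (init_time, valid_time, value) tuples.
    Keep, per valid_time, the latest run already available at the cutoff."""
    delay = timedelta(hours=delay_hours)
    latest = {}
    for init_time, valid_time, value in sorted(rows):
        if init_time + delay <= cutoff:
            latest[valid_time] = value  # later init_times overwrite earlier ones
    return latest

rows = [
    (datetime(2024, 1, 1, 0), datetime(2024, 1, 1, 12), 10.0),
    (datetime(2024, 1, 1, 6), datetime(2024, 1, 1, 12), 11.0),   # newer run, same valid time
    (datetime(2024, 1, 1, 12), datetime(2024, 1, 1, 12), 12.0),  # 12:00 run + 3 h > cutoff
]
collapsed = collapse_lead_times(rows, cutoff=datetime(2024, 1, 1, 12))
print(collapsed[datetime(2024, 1, 1, 12)])  # 11.0
```

The sort-then-overwrite loop mirrors the `sort(...).group_by(...).last()` chain in the Polars implementation.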

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/model.py
def predict(
    self,
    substation_metadata: pt.DataFrame[SubstationMetadata],
    inference_params: InferenceParams,
    flows_30m: pl.LazyFrame,
    nwps: Mapping[NwpModel, pl.LazyFrame] | None = None,
    collapse_lead_times: bool = False,
) -> pt.DataFrame[PowerForecast]:
    """Execute the inference logic.

    Args:
        substation_metadata: The substation metadata containing h3 mapping.
        inference_params: Parameters for inference.
        flows_30m: Historical power flow data downsampled to 30m (for lags).
        nwps: A dictionary of weather forecast lazyframes.
        collapse_lead_times: Whether to collapse lead times to simulate real-time inference by keeping only the latest available NWP forecast for each valid time.

    Returns:
        A Patito DataFrame containing the predictions.
    """
    if self.model is None:
        raise ValueError("Model must be trained before calling predict.")

    if not nwps:
        raise ValueError("XGBoostForecaster requires NWP data for prediction.")

    if self.target_map is None:
        raise ValueError("target_map must be set before calling predict.")

    df_lf = self._prepare_data_for_model(
        flows_30m=flows_30m,
        substation_metadata=substation_metadata,
        target_map_df=self.target_map,
        nwps=nwps,
        inference_params=inference_params,
        collapse_lead_times=collapse_lead_times,
    )

    feature_lf = self._prepare_features(df_lf)
    feature_cols = feature_lf.collect_schema().names()

    # Output columns needed for the final result
    output_cols = [
        NwpColumns.VALID_TIME,
        "substation_number",
        NwpColumns.ENSEMBLE_MEMBER,
        NwpColumns.INIT_TIME,
        "nwp_init_hour",
        "lead_time_hours",
    ]

    # Nulls are intentionally kept during prediction; XGBoost handles missing values natively.
    df = cast(
        pl.DataFrame,
        df_lf.select(
            list(set(feature_cols + output_cols + ["MW_or_MVA", "peak_capacity_MW_or_MVA"]))
        ).collect(),
    )

    if df.is_empty():
        raise ValueError("No inference data remaining.")

    SubstationFeatures.validate(df, allow_missing_columns=True, allow_superfluous_columns=True)

    X = cast(pl.DataFrame, self._prepare_features(df))

    if (
        X.select(
            pl.any_horizontal(
                pl.col(pl.Float32, pl.Float64).is_nan()
                | pl.col(pl.Float32, pl.Float64).is_infinite()
            )
        )
        .sum()
        .item()
        > 0
    ):
        raise ValueError("Input features X contain NaN or Inf values")

    # Enforce exact column order from training
    if hasattr(self.model, "feature_names_in_"):
        X = X.select(self.model.feature_names_in_)
    elif hasattr(self, "feature_names") and self.feature_names:
        X = X.select(self.feature_names)

    preds = self.model.predict(X.to_arrow())

    # Descale predictions and return with correct schema
    fcst_init_time = inference_params.forecast_time
    model_name = inference_params.power_fcst_model_name or "xgboost_global"

    res = (
        df.with_columns(
            MW_or_MVA=pl.Series(values=preds, dtype=pl.Float32),
        )
        .with_columns(
            MW_or_MVA=pl.col("MW_or_MVA") * pl.col("peak_capacity_MW_or_MVA"),
            power_fcst_model_name=pl.lit(model_name).cast(pl.Categorical),
            power_fcst_init_time=pl.lit(fcst_init_time).cast(pl.Datetime("us", "UTC")),
            nwp_init_time=pl.col(NwpColumns.INIT_TIME).cast(pl.Datetime("us", "UTC")),
            power_fcst_init_year_month=pl.lit(fcst_init_time.strftime("%Y-%m")).cast(pl.String),
        )
        .select(
            [
                NwpColumns.VALID_TIME,
                "substation_number",
                NwpColumns.ENSEMBLE_MEMBER,
                "power_fcst_model_name",
                "power_fcst_init_time",
                "nwp_init_time",
                "power_fcst_init_year_month",
                "nwp_init_hour",
                "lead_time_hours",
                "MW_or_MVA",
            ]
        )
    )

    # Ensure all substations in the inference set are present in the target_map
    # to prevent silent dropping of forecasts.
    if res.select("substation_number").n_unique() < df.select("substation_number").n_unique():
        missing_substations = set(df.select("substation_number").to_series()) - set(
            res.select("substation_number").to_series()
        )
        raise ValueError(
            f"The following substations are missing from the target_map: {missing_substations}. "
            "All substations in the inference set must have a corresponding entry in the "
            "target_map to prevent null forecasts."
        )

    # Handle potential nulls in ensemble_member (required by schema)
    res = res.with_columns(pl.col(NwpColumns.ENSEMBLE_MEMBER).fill_null(0).cast(pl.UInt8))

    return PowerForecast.validate(res)
train(config, flows_30m, substation_metadata, nwps=None)

Train the XGBoost model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `ModelConfig` | The model configuration object. | *required* |
| `flows_30m` | `LazyFrame` | Historical power flow data downsampled to 30m. | *required* |
| `substation_metadata` | `DataFrame[SubstationMetadata]` | The substation metadata containing the h3 mapping. | *required* |
| `nwps` | `Mapping[NwpModel, LazyFrame] \| None` | A dictionary of weather forecast dataframes. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `XGBoostForecaster` | The trained `XGBoostForecaster` instance. |

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/model.py
def train(
    self,
    config: ModelConfig,
    flows_30m: pl.LazyFrame,
    substation_metadata: pt.DataFrame[SubstationMetadata],
    nwps: Mapping[NwpModel, pl.LazyFrame] | None = None,
) -> "XGBoostForecaster":
    """Train the XGBoost model.

    Args:
        config: The model configuration object.
        flows_30m: Historical power flow data downsampled to 30m.
        substation_metadata: The substation metadata containing h3 mapping.
        nwps: A dictionary of weather forecast dataframes.

    Returns:
        The trained XGBoostForecaster instance.
    """
    self.config = config
    log.info("Starting XGBoost training...")

    # Log input data info
    # Note: We don't collect the full LazyFrames here to avoid OOM,
    # just logging their presence and schema.
    log.info(f"Input flows_30m columns: {flows_30m.collect_schema().names()}")
    if nwps:
        for name, lf in nwps.items():
            log.info(f"Input NWP {name.value} columns: {lf.collect_schema().names()}")

    if len(config.features.nwps) > 0 and not nwps:
        raise ValueError("Model config requires NWPs, but none were provided.")

    if self.target_map is None:
        raise ValueError("target_map must be set before calling train.")

    joined_lf = self._prepare_data_for_model(
        flows_30m=flows_30m,
        substation_metadata=substation_metadata,
        target_map_df=self.target_map,
        nwps=nwps,
    )

    # Prepare features and target
    feature_lf = self._prepare_features(joined_lf)
    feature_cols = feature_lf.collect_schema().names()

    # Collect only necessary columns and drop nulls
    critical_cols = ["MW_or_MVA"]
    if nwps:
        critical_cols.extend(
            [
                NwpColumns.TEMPERATURE_2M,
                NwpColumns.SW_RADIATION,
                NwpColumns.WIND_SPEED_10M,
            ]
        )

    # Apply random sampling if max_training_samples is set, to bound memory use
    # during XGBoost training.
    if config.max_training_samples is not None:
        log.info(f"Sampling training data to {config.max_training_samples} samples.")
        # LazyFrame has no sample(n=...) method, so we collect first and then
        # sample. This caps memory use in XGBoost training itself, although the
        # eager collection can still be the memory bottleneck.
        raw_df = cast(
            pl.DataFrame,
            joined_lf.select(list(set(feature_cols + ["MW_or_MVA"]))).collect(),
        ).sample(n=config.max_training_samples, seed=42)
    else:
        raw_df = cast(
            pl.DataFrame,
            joined_lf.select(list(set(feature_cols + ["MW_or_MVA"]))).collect(),
        )
    log.info(f"Collected raw_df shape before dropping nulls: {raw_df.shape}")
    joined_df = raw_df.drop_nulls(subset=critical_cols)

    dropped_rows = len(raw_df) - len(joined_df)
    if dropped_rows > 0:
        log.warning(
            f"Dropped {dropped_rows} rows during training due to nulls in critical columns: {critical_cols}"
        )

    if joined_df.is_empty():
        raise ValueError("No training data remaining after dropping nulls in critical columns.")

    SubstationFeatures.validate(
        joined_df, allow_missing_columns=True, allow_superfluous_columns=True
    )

    X = cast(pl.DataFrame, self._prepare_features(joined_df))
    y = joined_df.select("MW_or_MVA").to_series()

    # NaN/Inf checks
    if (
        X.select(
            pl.any_horizontal(
                pl.col(pl.Float32, pl.Float64).is_nan()
                | pl.col(pl.Float32, pl.Float64).is_infinite()
            )
        )
        .sum()
        .item()
        > 0
    ):
        raise ValueError("Input features X contain NaN or Inf values")

    if y.is_nan().any() or y.is_infinite().any():
        raise ValueError("Target y contains NaN or Inf values")

    # Save feature names
    self.feature_names = X.columns

    hyperparams = XGBoostHyperparameters(**config.hyperparameters)
    model = XGBRegressor(**hyperparams.model_dump())
    model.fit(X.to_arrow(), y.to_arrow())
    self.model = model

    return self

Functions

xgboost_forecaster.scaling

Classes

xgboost_forecaster.types

Classes

EnsembleSelection

Bases: str, Enum

Selection method for weather ensemble members.

Source code in packages/xgboost_forecaster/src/xgboost_forecaster/types.py
class EnsembleSelection(str, Enum):
    """Selection method for weather ensemble members."""

    MEAN = "mean"
    SINGLE = "single"
    ALL = "all"
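Because `EnsembleSelection` subclasses `str`, its members compare equal to plain strings, which is convenient when the value arrives from a YAML or JSON config. A self-contained sketch reproducing the enum (copied from the listing above):

```python
from enum import Enum


class EnsembleSelection(str, Enum):
    """Selection method for weather ensemble members."""

    MEAN = "mean"
    SINGLE = "single"
    ALL = "all"


# Parse a raw config string into the enum; invalid values raise ValueError.
selection = EnsembleSelection("mean")
```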