
Contracts API

Contracts

Defines the "data contracts": the schemas that specify the precise shape and semantics of each data source.

Dependency Isolation

This package is designed to be extremely lightweight. It defines the shape of the data using Patito and Polars, but it does not contain any ML-specific logic or heavy dependencies like MLflow. This ensures that any component in the system (e.g., a data ingestion script or a dashboard) can import these schemas without bringing in the entire ML stack.

Key Data Contracts

  • SubstationFeatures: The final joined dataset ready for ML model training and inference. It enforces strict validation for critical features, including the dynamically calculated latest_available_weekly_power_lag (which prevents lookahead bias) and the ensemble_member field.
  • PowerForecast: The schema for deterministic ensemble forecasts generated by the ML models. It includes fields for tracking the model name, initialization times (both power forecast and underlying NWP), and the predicted power flow.
  • ProcessedNwp: Weather data after ensemble selection and interpolation.
  • ModelConfig: Configuration schema for ML models, defining hyperparameters, feature selection, and critical forecasting parameters such as:
      • required_lookback_days: The amount of historical data required for dynamic autoregressive lags (e.g., 14 days).
      • nwp_availability_delay_hours: The delay before an NWP forecast becomes available for inference (e.g., 3 hours), used to prevent lookahead bias during backtesting.
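The `nwp_availability_delay_hours` rule can be sketched in plain Python: an NWP run is only usable at a given `forecast_time` if its `init_time` plus the availability delay has already passed. This is an illustrative sketch of the constraint, not the pipeline's implementation (the function name is hypothetical):

```python
from datetime import datetime, timedelta


def usable_nwp_init_times(
    init_times: list[datetime],
    forecast_time: datetime,
    nwp_availability_delay_hours: int = 3,
) -> list[datetime]:
    """Keep only NWP runs that were actually available at `forecast_time`,
    i.e. init_time + delay <= forecast_time. During backtesting this is the
    check that prevents lookahead bias."""
    delay = timedelta(hours=nwp_availability_delay_hours)
    return [t for t in init_times if t + delay <= forecast_time]
```

For example, with a 3-hour delay, the 12:00 UTC run is not yet usable for a forecast made at 13:00 UTC, because it only becomes available at 15:00 UTC.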

Design principles

  • Naming of columns: Prefer snake_case, except for acronyms and SI units. For example, capitalise "DER" (the acronym for "distributed energy resource") and use upper case for "MW" (megawatts).
  • Semantic checks: Range checks should be fairly generous. The aim is to catch physically impossible values, rather than possible-but-unlikely values.
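As a minimal sketch of the "generous range" principle, a plain-Python version of the ±1,000 MW bound used by `SubstationPowerFlows` below might look like this (the function name is illustrative, not part of the package):

```python
def is_physically_plausible_power_mw(value: float) -> bool:
    """Generous semantic check: primary substations usually carry flows in
    the tens of MW, but we only reject values outside +/-1000 MW, which are
    physically impossible rather than merely unlikely."""
    return -1_000.0 <= value <= 1_000.0
```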

contracts.data_schemas

Data schemas for the NGED substation forecast project.

Classes

H3GridWeights

Bases: Model

Schema for the pre-computed H3 grid weights.

This contract defines the mapping between H3 hexagons and a regular latitude/longitude grid. It is used to ensure type safety when passing spatial mapping data from generic geospatial utilities (like packages/geo) to dataset-specific ingestion pipelines (like packages/dynamical_data).

Source code in packages/contracts/src/contracts/data_schemas.py
class H3GridWeights(pt.Model):
    """Schema for the pre-computed H3 grid weights.

    This contract defines the mapping between H3 hexagons and a regular latitude/longitude grid.
    It is used to ensure type safety when passing spatial mapping data from generic geospatial
    utilities (like `packages/geo`) to dataset-specific ingestion pipelines (like `packages/dynamical_data`).
    """

    h3_index: int = pt.Field(dtype=pl.UInt64)
    nwp_lat: float = pt.Field(dtype=pl.Float64, ge=-90, le=90)
    nwp_lng: float = pt.Field(dtype=pl.Float64, ge=-180, le=180)
    len: int = pt.Field(dtype=pl.UInt32)
    total: int = pt.Field(dtype=pl.UInt32)
    proportion: float = pt.Field(dtype=pl.Float64)

InferenceParams

Bases: BaseModel

Parameters for ML model inference.

Source code in packages/contracts/src/contracts/data_schemas.py
class InferenceParams(BaseModel):
    """Parameters for ML model inference."""

    # The time that we create our power forecast. This might be called `t0` in some other OCF
    # projects. When running backtests, we cannot use any NWPs available after `forecast_time`
    # (i.e. init_time + delay > forecast_time).
    forecast_time: datetime

    power_fcst_model_name: str | None = None

MissingCorePowerVariablesError

Bases: ValueError

Raised when a substation CSV lacks both MW and MVA data.

Source code in packages/contracts/src/contracts/data_schemas.py
class MissingCorePowerVariablesError(ValueError):
    """Raised when a substation CSV lacks both MW and MVA data."""

    pass

Nwp

Bases: Model

Weather data schema for NWP forecasts.

Source code in packages/contracts/src/contracts/data_schemas.py
class Nwp(pt.Model):
    """Weather data schema for NWP forecasts."""

    init_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)
    valid_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)
    ensemble_member: int = pt.Field(dtype=pl.UInt8)
    h3_index: int = pt.Field(dtype=pl.UInt64)

    # Variables stored as Float32 in memory (descaled from uint8 on disk)
    temperature_2m: float = pt.Field(dtype=pl.Float32)
    dew_point_temperature_2m: float = pt.Field(dtype=pl.Float32)
    # WIND VECTOR COMPONENTS:
    # We store raw U and V components as Float32 to allow physically realistic
    # linear interpolation in the forecasting pipeline, avoiding the "phantom high wind"
    # artifacts caused by interpolating speed/direction or circular variables.
    wind_u_10m: float = pt.Field(dtype=pl.Float32)
    wind_v_10m: float = pt.Field(dtype=pl.Float32)
    wind_u_100m: float = pt.Field(dtype=pl.Float32)
    wind_v_100m: float = pt.Field(dtype=pl.Float32)
    pressure_surface: float = pt.Field(dtype=pl.Float32)
    pressure_reduced_to_mean_sea_level: float = pt.Field(dtype=pl.Float32)
    geopotential_height_500hpa: float = pt.Field(dtype=pl.Float32)

    # Precipitation and radiation variables are null for the first forecast step (lead time 0) in
    # ECMWF ENS. Also note that, whilst these variables accumulate over forecast steps in ECMWF's
    # raw forecasts, we get ECMWF ENS from Dynamical.org, and Dynamical.org de-accumulates these
    # values before we receive them. So these are true _rates_.
    downward_long_wave_radiation_flux_surface: float | None = pt.Field(dtype=pl.Float32)
    downward_short_wave_radiation_flux_surface: float | None = pt.Field(dtype=pl.Float32)
    precipitation_surface: float | None = pt.Field(dtype=pl.Float32)

    categorical_precipitation_type_surface: int = pt.Field(dtype=pl.UInt8)

    @classmethod
    def validate(
        cls,
        dataframe: pl.DataFrame | "pd.DataFrame",
        columns: Sequence[str] | None = None,
        allow_missing_columns: bool = False,
        allow_superfluous_columns: bool = False,
        drop_superfluous_columns: bool = False,
    ) -> pt.DataFrame["Nwp"]:
        """Validate the given dataframe, ensuring no nulls from second step onwards."""
        validated_df = super().validate(
            dataframe=dataframe,
            columns=columns,
            allow_missing_columns=allow_missing_columns,
            allow_superfluous_columns=allow_superfluous_columns,
            drop_superfluous_columns=drop_superfluous_columns,
        )

        # Check for nulls from second forecast step onwards
        # (i.e. where valid_time > init_time)
        cols_to_check = [
            "precipitation_surface",
            "downward_short_wave_radiation_flux_surface",
            "downward_long_wave_radiation_flux_surface",
        ]

        second_step_onwards = validated_df.filter(pl.col("valid_time") > pl.col("init_time"))

        for col in cols_to_check:
            null_count = second_step_onwards.select(pl.col(col).is_null().sum()).item()
            if null_count > 0:
                raise ValueError(
                    f"Column '{col}' contains {null_count} null values from the second forecast "
                    "step onwards. These variables are only allowed to be null for the first "
                    "forecast step (lead time 0)."
                )

        return cast(pt.DataFrame["Nwp"], validated_df)
Functions
validate(dataframe, columns=None, allow_missing_columns=False, allow_superfluous_columns=False, drop_superfluous_columns=False) classmethod

Validate the given dataframe, ensuring no nulls from second step onwards.

Source code in packages/contracts/src/contracts/data_schemas.py
@classmethod
def validate(
    cls,
    dataframe: pl.DataFrame | "pd.DataFrame",
    columns: Sequence[str] | None = None,
    allow_missing_columns: bool = False,
    allow_superfluous_columns: bool = False,
    drop_superfluous_columns: bool = False,
) -> pt.DataFrame["Nwp"]:
    """Validate the given dataframe, ensuring no nulls from second step onwards."""
    validated_df = super().validate(
        dataframe=dataframe,
        columns=columns,
        allow_missing_columns=allow_missing_columns,
        allow_superfluous_columns=allow_superfluous_columns,
        drop_superfluous_columns=drop_superfluous_columns,
    )

    # Check for nulls from second forecast step onwards
    # (i.e. where valid_time > init_time)
    cols_to_check = [
        "precipitation_surface",
        "downward_short_wave_radiation_flux_surface",
        "downward_long_wave_radiation_flux_surface",
    ]

    second_step_onwards = validated_df.filter(pl.col("valid_time") > pl.col("init_time"))

    for col in cols_to_check:
        null_count = second_step_onwards.select(pl.col(col).is_null().sum()).item()
        if null_count > 0:
            raise ValueError(
                f"Column '{col}' contains {null_count} null values from the second forecast "
                "step onwards. These variables are only allowed to be null for the first "
                "forecast step (lead time 0)."
            )

    return cast(pt.DataFrame["Nwp"], validated_df)
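The "phantom high wind" comment in `Nwp` above can be demonstrated numerically. This standalone sketch (not pipeline code) interpolates halfway between two opposing winds of equal speed: interpolating the U/V components correctly yields calm air, whereas naively interpolating the speeds reports full wind speed:

```python
import math

# Two opposing winds of equal speed: 10 m/s blowing east (+U), then
# 10 m/s blowing west (-U).
u1, v1 = 10.0, 0.0
u2, v2 = -10.0, 0.0

# Physically realistic: interpolate U/V components, then derive speed.
u_mid = (u1 + u2) / 2
v_mid = (v1 + v2) / 2
speed_from_components = math.hypot(u_mid, v_mid)  # 0.0 m/s: calm, as expected

# Naive: interpolate the speeds directly, inventing a "phantom high wind".
speed_naive = (math.hypot(u1, v1) + math.hypot(u2, v2)) / 2  # 10.0 m/s
```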

NwpColumns

Centralized constants for NWP column names.

Used to prevent typos and ensure consistency across feature engineering and model training.

Source code in packages/contracts/src/contracts/data_schemas.py
class NwpColumns:
    """Centralized constants for NWP column names.

    Used to prevent typos and ensure consistency across feature engineering and model training.
    """

    VALID_TIME = "valid_time"
    INIT_TIME = "init_time"
    LEAD_TIME_HOURS = "lead_time_hours"
    H3_INDEX = "h3_index"
    ENSEMBLE_MEMBER = "ensemble_member"
    TEMPERATURE_2M = "temperature_2m"
    WIND_SPEED_10M = "wind_speed_10m"
    WIND_DIRECTION_10M = "wind_direction_10m"
    SW_RADIATION = "downward_short_wave_radiation_flux_surface"

PowerForecast

Bases: Model

Forecast data schema for deterministic ensemble forecasts.

Source code in packages/contracts/src/contracts/data_schemas.py
class PowerForecast(pt.Model):
    """Forecast data schema for deterministic ensemble forecasts."""

    valid_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)
    substation_number: int = pt.Field(dtype=pl.Int32)
    ensemble_member: int = pt.Field(dtype=pl.UInt8)

    # The datetime that the underlying weather forecast was initialised.
    nwp_init_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)

    # The hour of the day that the weather forecast was initialised (0, 6, 12, 18).
    nwp_init_hour: int = pt.Field(dtype=pl.Int32)

    # The number of hours between the weather forecast initialisation and the valid time.
    lead_time_hours: float = pt.Field(dtype=pl.Float32)

    # Identifier for our ML-based power forecasting model.
    # This is manually specified in `hydra_schemas.ModelConfig.power_fcst_model_name`.
    power_fcst_model_name: str = pt.Field(dtype=pl.Categorical)

    # The datetime that the power forecast was initialised.
    power_fcst_init_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)

    # Year and month of the power forecast initialisation (for partitioning).
    power_fcst_init_year_month: str = pt.Field(dtype=pl.String)

    # The power forecast itself in units of MW (active power) or MVA (apparent power).
    MW_or_MVA: float = pt.Field(dtype=pl.Float32)

ProcessedNwp

Bases: Model

Weather data after ensemble selection and interpolation.

Note: Accumulated variables (e.g., precipitation, radiation) are already de-accumulated by Dynamical.org prior to download, and should not be differenced.

Clever Optimization: To save memory, weather variables are scaled to a 0-255 range (uint8) before being saved to disk. The scaling formula is: uint8_value = round(((physical_value - buffered_min) / buffered_range) * 255).

When loaded, they are cast to Float32 but retain the 0-255 scale.
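The scaling round trip follows directly from the formula above. This is a pure-Python sketch (function names are illustrative); `buffered_min` and `buffered_range` come from `ScalingParams`:

```python
def to_uint8(physical_value: float, buffered_min: float, buffered_range: float) -> int:
    """Scale a physical value into the 0-255 uint8 range used on disk."""
    return round(((physical_value - buffered_min) / buffered_range) * 255)


def from_uint8(scaled: float, buffered_min: float, buffered_range: float) -> float:
    """Approximate inverse: recover the physical value, to within the
    quantisation step of buffered_range / 255."""
    return (scaled / 255) * buffered_range + buffered_min
```

For example, with `buffered_min=-40` and `buffered_range=90` (a plausible 2 m temperature range in °C), 20 °C maps to 170 and back to within quantisation error.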

Source code in packages/contracts/src/contracts/data_schemas.py
class ProcessedNwp(pt.Model):
    """Weather data after ensemble selection and interpolation.

    Note: Accumulated variables (e.g., precipitation, radiation) are already
    de-accumulated by Dynamical.org prior to download, and should not be differenced.

    Clever Optimization:
    To save memory, weather variables are scaled to a 0-255 range (uint8) before being saved to disk.
    The scaling formula is:
        uint8_value = round(((physical_value - buffered_min) / buffered_range) * 255).

    When loaded, they are cast to Float32 but retain the 0-255 scale.
    """

    valid_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)
    init_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)
    lead_time_hours: float = pt.Field(dtype=pl.Float32)
    h3_index: int = pt.Field(dtype=pl.UInt64)
    ensemble_member: int | None = pt.Field(dtype=pl.UInt8, allow_missing=True)

    # Weather variables as Float32
    temperature_2m: float = pt.Field(dtype=pl.Float32)
    dew_point_temperature_2m: float = pt.Field(dtype=pl.Float32)
    wind_speed_10m: float = pt.Field(dtype=pl.Float32)
    wind_direction_10m: float = pt.Field(dtype=pl.Float32)
    wind_speed_100m: float = pt.Field(dtype=pl.Float32)
    wind_direction_100m: float = pt.Field(dtype=pl.Float32)
    pressure_surface: float = pt.Field(dtype=pl.Float32)
    pressure_reduced_to_mean_sea_level: float = pt.Field(dtype=pl.Float32)
    geopotential_height_500hpa: float = pt.Field(dtype=pl.Float32)
    downward_long_wave_radiation_flux_surface: float | None = pt.Field(dtype=pl.Float32)
    downward_short_wave_radiation_flux_surface: float | None = pt.Field(dtype=pl.Float32)
    precipitation_surface: float | None = pt.Field(dtype=pl.Float32)
    categorical_precipitation_type_surface: int = pt.Field(dtype=pl.UInt8)

ScalingParams

Bases: Model

Schema for weather variable scaling parameters.

Used when scaling between physical units (e.g. degrees C) and their unsigned 8-bit integer (uint8) representations. uint8 represents integers in the range [0, 255].

Source code in packages/contracts/src/contracts/data_schemas.py
class ScalingParams(pt.Model):
    """Schema for weather variable scaling parameters.

    Used when scaling between physical units (e.g. degrees C) and their unsigned 8-bit integer
    (uint8) representations. uint8 represents integers in the range [0, 255]."""

    col_name: str = pt.Field(dtype=pl.String)

    # The minimum from the actual values, minus a small buffer.
    buffered_min: float = pt.Field(dtype=pl.Float32)

    # The range from the actual values, plus a small buffer.
    buffered_range: float = pt.Field(dtype=pl.Float32)

    # The maximum from the actual values, plus a small buffer.
    buffered_max: float = pt.Field(dtype=pl.Float32)

SimplifiedSubstationPowerFlows

Bases: Model

Standardized, single-column representation of power flows.

This model is used after the best available power column (MW or MVA) has been selected and renamed to 'MW_or_MVA'.

Source code in packages/contracts/src/contracts/data_schemas.py
class SimplifiedSubstationPowerFlows(pt.Model):
    """Standardized, single-column representation of power flows.

    This model is used after the best available power column (MW or MVA) has been
    selected and renamed to 'MW_or_MVA'.
    """

    timestamp: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)
    substation_number: int = pt.Field(dtype=pl.Int32)
    MW_or_MVA: float = pt.Field(dtype=pl.Float32, ge=-1_000, le=1_000)

SubstationFeatures

Bases: Model

Final joined dataset ready for XGBoost.

Weather features are kept in their physical units (e.g., degrees Celsius, m/s) to ensure precision during interpolation and feature engineering.

Source code in packages/contracts/src/contracts/data_schemas.py
class SubstationFeatures(pt.Model):
    """Final joined dataset ready for XGBoost.

    Weather features are kept in their physical units (e.g., degrees Celsius, m/s)
    to ensure precision during interpolation and feature engineering.
    """

    valid_time: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)
    substation_number: int = pt.Field(dtype=pl.Int32)
    ensemble_member: int | None = pt.Field(dtype=pl.UInt8, allow_missing=True)
    MW_or_MVA: float = pt.Field(dtype=pl.Float32)
    lead_time_hours: float = pt.Field(dtype=pl.Float32)
    lead_time_days: float = pt.Field(dtype=pl.Float32)
    nwp_init_hour: int = pt.Field(dtype=pl.Int32)

    # Power lags
    latest_available_weekly_power_lag: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)

    # Weather features
    temperature_2m: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    dew_point_temperature_2m: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    # PHYSICAL WIND FEATURES:
    # These are calculated from interpolated U/V components in the forecasting
    # pipeline, ensuring physically realistic wind speed and direction.
    wind_speed_10m: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    wind_direction_10m: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    wind_speed_100m: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    wind_direction_100m: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    pressure_surface: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    pressure_reduced_to_mean_sea_level: float | None = pt.Field(
        dtype=pl.Float32, allow_missing=True
    )
    geopotential_height_500hpa: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    downward_long_wave_radiation_flux_surface: float | None = pt.Field(
        dtype=pl.Float32, allow_missing=True
    )
    downward_short_wave_radiation_flux_surface: float | None = pt.Field(
        dtype=pl.Float32, allow_missing=True
    )
    precipitation_surface: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    categorical_precipitation_type_surface: int | None = pt.Field(
        dtype=pl.UInt8, allow_missing=True
    )

    # Physical features
    windchill: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)

    # Weather lags/trends
    temperature_2m_lag_7d: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    downward_short_wave_radiation_flux_surface_lag_7d: float | None = pt.Field(
        dtype=pl.Float32, allow_missing=True
    )
    temperature_2m_lag_14d: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    downward_short_wave_radiation_flux_surface_lag_14d: float | None = pt.Field(
        dtype=pl.Float32, allow_missing=True
    )
    temperature_2m_6h_ago: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)
    temp_trend_6h: float | None = pt.Field(dtype=pl.Float32, allow_missing=True)

    # Temporal features
    hour_sin: float = pt.Field(dtype=pl.Float32)
    hour_cos: float = pt.Field(dtype=pl.Float32)
    day_of_year_sin: float = pt.Field(dtype=pl.Float32)
    day_of_year_cos: float = pt.Field(dtype=pl.Float32)
    day_of_week: int = pt.Field(dtype=pl.Int8)
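The cyclical temporal features (`hour_sin`/`hour_cos`, `day_of_year_sin`/`day_of_year_cos`) are presumably the standard unit-circle encoding; the source excerpt doesn't show the formula, so the sketch below assumes it:

```python
import math


def cyclical_hour_features(hour: float) -> tuple[float, float]:
    """Encode hour-of-day on the unit circle, so 23:00 and 00:00 end up
    close together in feature space rather than 23 units apart."""
    angle = 2 * math.pi * hour / 24
    return math.sin(angle), math.cos(angle)
```

This gives tree-based models such as XGBoost a continuous representation of time-of-day with no artificial midnight discontinuity.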

SubstationLocations

Bases: Model

The data structure of the raw substation location data from NGED.

Source code in packages/contracts/src/contracts/data_schemas.py
class SubstationLocations(pt.Model):
    """The data structure of the raw substation location data from NGED."""

    # NGED has 192,000 substations.
    substation_number: int = pt.Field(dtype=pl.Int32, unique=True, gt=0, lt=1_000_000)

    # The min and max string lengths are actually 3 and 48 chars, respectively.
    # Note that there are two "Park Lane" substations, with different locations and different
    # substation numbers.
    substation_name: str = pt.Field(dtype=pl.String, min_length=2, max_length=64)

    substation_type: str = pt.Field(dtype=pl.Categorical)
    latitude: float | None = pt.Field(dtype=pl.Float32, ge=49, le=61)  # UK latitude range
    longitude: float | None = pt.Field(dtype=pl.Float32, ge=-9, le=2)  # UK longitude range

SubstationLocationsWithH3

Bases: SubstationLocations

Substation locations including their H3 index.

Source code in packages/contracts/src/contracts/data_schemas.py
class SubstationLocationsWithH3(SubstationLocations):
    """Substation locations including their H3 index."""

    h3_res_5: int | None = pt.Field(dtype=pl.UInt64)

SubstationMetadata

Bases: Model

Metadata for a substation, joining location data with live telemetry info.

Source code in packages/contracts/src/contracts/data_schemas.py
class SubstationMetadata(pt.Model):
    """Metadata for a substation, joining location data with live telemetry info."""

    # NGED has 192,000 substations.
    substation_number: int = pt.Field(dtype=pl.Int32, unique=True, gt=0, lt=1_000_000)

    # NGED's CKAN portal uses slightly different names for some substations in their location table
    # versus in their live primary flows data. These names are matched by code in
    # `packages/nged_data/src/nged_data/substation_names/`
    substation_name_in_location_table: str = pt.Field(dtype=pl.String, min_length=2, max_length=64)

    # This will be null if the substation doesn't have live telemetry.
    substation_name_in_live_primaries: str | None = pt.Field(
        dtype=pl.String, min_length=2, max_length=128, allow_missing=True
    )

    # The URL to the live telemetry CSV on NGED's CKAN portal.
    url: str | None = pt.Field(dtype=pl.String, allow_missing=True)

    substation_type: str = pt.Field(dtype=pl.Categorical)
    latitude: float | None = pt.Field(dtype=pl.Float32, ge=49, le=61)  # UK latitude range
    longitude: float | None = pt.Field(dtype=pl.Float32, ge=-9, le=2)  # UK longitude range
    h3_res_5: int | None = pt.Field(dtype=pl.UInt64)  # H3 discrete spatial index

    # When this metadata record was last updated from the upstream NGED datasets.
    last_updated: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)

    # A globally computed preference for which power column (MW or MVA) to use, based on full history.
    # This prioritizes MW but falls back to MVA if MW is unavailable or contains dead sensors.
    preferred_power_col: str | None = pt.Field(dtype=pl.String, allow_missing=True)

SubstationPowerFlows

Bases: Model

Source code in packages/contracts/src/contracts/data_schemas.py
class SubstationPowerFlows(pt.Model):
    timestamp: datetime = pt.Field(dtype=UTC_DATETIME_DTYPE)

    # The unique identifier for the substation.
    substation_number: int = pt.Field(dtype=pl.Int32)

    # Primary substations usually have flows in the tens of MW.
    # We'll set a loose range for now to catch extreme errors.
    # If we want to reduce storage space we could store kW and kVAr as Int16.

    # Active power:
    MW: float | None = pt.Field(dtype=pl.Float32, ge=-1_000, le=1_000)

    # Apparent power:
    MVA: float | None = pt.Field(dtype=pl.Float32, ge=-1_000, le=1_000)

    # Reactive power:
    MVAr: float | None = pt.Field(dtype=pl.Float32, ge=-1_000, le=1_000)

    # The datetime this data was ingested into our system. When we update our datasets, we examine
    # `ingested_at` to figure out whether we need to get new data from NGED for this substation.
    # `ingested_at` is only missing for data ingested before around mid-March 2026 (prior to this,
    # we didn't record when the data was ingested).
    ingested_at: datetime | None = pt.Field(dtype=UTC_DATETIME_DTYPE)

    @classmethod
    def validate(
        cls,
        dataframe: pl.DataFrame | "pd.DataFrame",
        columns: Sequence[str] | None = None,
        allow_missing_columns: bool = False,
        allow_superfluous_columns: bool = False,
        drop_superfluous_columns: bool = False,
    ) -> pt.DataFrame["SubstationPowerFlows"]:
        """Validate the given dataframe, ensuring either MW or MVA is present and has data.

        NOTE: Fully null DataFrames are allowed to handle edge cases where:
        1. An entire partition's data was cleaned and all values marked as stuck/insane
        2. Ingestion failed completely for a partition (empty DataFrame after filtering)

        In these cases, the validation passes through to the parent class which allows
        null values for the columns. This prevents pipeline crashes from legitimate empty
        data scenarios. The downstream model training logic will need to handle fully
        null target variables by either skipping training or using fallback strategies.
        """
        # Ensure at least one of MW or MVA has non-null data
        # Only raise error if there IS data but MW/MVA columns have no non-null values.
        if len(dataframe) > 0:
            mw_has_data = POWER_MW in dataframe.columns and dataframe[POWER_MW].is_not_null().any()
            mva_has_data = (
                POWER_MVA in dataframe.columns and dataframe[POWER_MVA].is_not_null().any()
            )

            if not mw_has_data and not mva_has_data:
                raise MissingCorePowerVariablesError(
                    f"SubstationPowerFlows dataframe must have non-null data in either '{POWER_MW}' "
                    f"or '{POWER_MVA}' unless the entire DataFrame is empty (which is allowed for "
                    "edge cases)."
                )

        return cast(
            pt.DataFrame["SubstationPowerFlows"],
            super().validate(
                dataframe=dataframe,
                columns=columns,
                allow_missing_columns=allow_missing_columns,
                allow_superfluous_columns=allow_superfluous_columns,
                drop_superfluous_columns=drop_superfluous_columns,
            ),
        )
Functions
validate(dataframe, columns=None, allow_missing_columns=False, allow_superfluous_columns=False, drop_superfluous_columns=False) classmethod

Validate the given dataframe, ensuring either MW or MVA is present and has data.

NOTE: Fully null DataFrames are allowed to handle edge cases where:

1. An entire partition's data was cleaned and all values marked as stuck/insane
2. Ingestion failed completely for a partition (empty DataFrame after filtering)

In these cases, the validation passes through to the parent class which allows null values for the columns. This prevents pipeline crashes from legitimate empty data scenarios. The downstream model training logic will need to handle fully null target variables by either skipping training or using fallback strategies.

Source code in packages/contracts/src/contracts/data_schemas.py
@classmethod
def validate(
    cls,
    dataframe: pl.DataFrame | "pd.DataFrame",
    columns: Sequence[str] | None = None,
    allow_missing_columns: bool = False,
    allow_superfluous_columns: bool = False,
    drop_superfluous_columns: bool = False,
) -> pt.DataFrame["SubstationPowerFlows"]:
    """Validate the given dataframe, ensuring either MW or MVA is present and has data.

    NOTE: Fully null DataFrames are allowed to handle edge cases where:
    1. An entire partition's data was cleaned and all values marked as stuck/insane
    2. Ingestion failed completely for a partition (empty DataFrame after filtering)

    In these cases, the validation passes through to the parent class which allows
    null values for the columns. This prevents pipeline crashes from legitimate empty
    data scenarios. The downstream model training logic will need to handle fully
    null target variables by either skipping training or using fallback strategies.
    """
    # Ensure at least one of MW or MVA has non-null data
    # Only raise error if there IS data but MW/MVA columns have no non-null values.
    if len(dataframe) > 0:
        mw_has_data = POWER_MW in dataframe.columns and dataframe[POWER_MW].is_not_null().any()
        mva_has_data = (
            POWER_MVA in dataframe.columns and dataframe[POWER_MVA].is_not_null().any()
        )

        if not mw_has_data and not mva_has_data:
            raise MissingCorePowerVariablesError(
                f"SubstationPowerFlows dataframe must have non-null data in either '{POWER_MW}' "
                f"or '{POWER_MVA}' unless the entire DataFrame is empty (which is allowed for "
                "edge cases)."
            )

    return cast(
        pt.DataFrame["SubstationPowerFlows"],
        super().validate(
            dataframe=dataframe,
            columns=columns,
            allow_missing_columns=allow_missing_columns,
            allow_superfluous_columns=allow_superfluous_columns,
            drop_superfluous_columns=drop_superfluous_columns,
        ),
    )

SubstationTargetMap

Bases: Model

Maps substations to their primary power column and stores their peak capacity.

This model is used to determine whether to use MW or MVA as the target variable for a given substation, and provides the peak capacity for scaling and validation.

Source code in packages/contracts/src/contracts/data_schemas.py
class SubstationTargetMap(pt.Model):
    """Maps substations to their primary power column and stores their peak capacity.

    This model is used to determine whether to use MW or MVA as the target variable
    for a given substation, and provides the peak capacity for scaling and validation.
    """

    substation_number: int = pt.Field(dtype=pl.Int32, unique=True)
    preferred_power_col: PowerColumn = pt.Field(dtype=pl.String)
    peak_capacity_MW_or_MVA: float = pt.Field(dtype=pl.Float32, gt=0)

contracts.hydra_schemas

Hydra configuration schemas for the NGED substation forecast project.

Classes

DataSplitConfig

Bases: BaseModel

Configuration for temporal data splitting.

Source code in packages/contracts/src/contracts/hydra_schemas.py
class DataSplitConfig(BaseModel):
    """Configuration for temporal data splitting."""

    train_start: date
    train_end: date
    test_start: date
    test_end: date

ModelConfig

Bases: BaseModel

Configuration for the ML model.

Source code in packages/contracts/src/contracts/hydra_schemas.py
class ModelConfig(BaseModel):
    """Configuration for the ML model."""

    power_fcst_model_name: str = Field(
        ...,
        description=(
            "A unique identifier for this model configuration. This name is used to label "
            "predictions in the evaluation results and to identify the model in MLflow. "
            "Users should use this as free text to describe substantial differences between "
            "different versions of the same underlying model (e.g., 'xgboost_baseline', "
            "'xgboost_with_solar_features')."
        ),
    )
    hyperparameters: dict[str, Any] = Field(default_factory=dict)
    required_lookback_days: int = Field(default=21)
    features: ModelFeaturesConfig

    # The latency between the NWP init time and when the NWP is actually downloaded and processed
    # and ready for use.
    nwp_availability_delay_hours: int = Field(default=3)

    # The latency between the telemetry timestamp and when it is actually available for use
    # in our forecasting pipeline.
    telemetry_delay_hours: int = Field(default=24)

    # Maximum number of samples to use for training to prevent OOM errors.
    # If set, the training data will be randomly sampled before collection.
    max_training_samples: int | None = Field(default=None, gt=0)

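The two delay fields translate directly into data cut-offs during backtesting and inference. A minimal sketch of the arithmetic, assuming 6-hourly NWP init times (the helper functions here are illustrative, not part of the package):

```python
from datetime import datetime, timedelta

NWP_AVAILABILITY_DELAY_HOURS = 3  # ModelConfig.nwp_availability_delay_hours
TELEMETRY_DELAY_HOURS = 24        # ModelConfig.telemetry_delay_hours


def latest_usable_nwp_init(now: datetime, init_interval_hours: int = 6) -> datetime:
    """Most recent NWP init time whose output is already available at `now`.

    A run initialised at T only becomes usable at T + the availability delay,
    so using anything newer at inference time would introduce lookahead bias.
    """
    available_before = now - timedelta(hours=NWP_AVAILABILITY_DELAY_HOURS)
    hour = (available_before.hour // init_interval_hours) * init_interval_hours
    return available_before.replace(hour=hour, minute=0, second=0, microsecond=0)


def latest_usable_telemetry(now: datetime) -> datetime:
    """Newest telemetry timestamp visible to the forecasting pipeline at `now`."""
    return now - timedelta(hours=TELEMETRY_DELAY_HOURS)


now = datetime(2024, 6, 1, 14, 30)
# At 14:30 the cutoff is 11:30, so the 12:00 run is not yet available:
print(latest_usable_nwp_init(now))  # 2024-06-01 06:00:00
```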
ModelFeaturesConfig

Bases: BaseModel

Configuration for model features.

Source code in packages/contracts/src/contracts/hydra_schemas.py
class ModelFeaturesConfig(BaseModel):
    """Configuration for model features."""

    nwps: list[NwpModel] = Field(default_factory=list)
    feature_names: list[str] = Field(default_factory=list)

NwpModel

Bases: str, Enum

Available NWP datasets.

Source code in packages/contracts/src/contracts/hydra_schemas.py
class NwpModel(str, Enum):
    """Available NWP datasets."""

    ECMWF_ENS_0_25DEG = "ecmwf_ens_0_25deg"
    GFS_0_25DEG = "gfs_0_25deg"

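Because the enum also subclasses `str`, members compare equal to their raw string values, which keeps Hydra/YAML configs human-readable while still restricting the allowed values. A quick demonstration (the class body is copied from the source above):

```python
from enum import Enum


class NwpModel(str, Enum):
    """Available NWP datasets."""

    ECMWF_ENS_0_25DEG = "ecmwf_ens_0_25deg"
    GFS_0_25DEG = "gfs_0_25deg"


# Round-trip from a plain config string back to the enum member:
assert NwpModel("ecmwf_ens_0_25deg") is NwpModel.ECMWF_ENS_0_25DEG
# str subclassing means direct comparison with string values works:
assert NwpModel.GFS_0_25DEG == "gfs_0_25deg"
print([m.value for m in NwpModel])  # ['ecmwf_ens_0_25deg', 'gfs_0_25deg']
```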
TrainingConfig

Bases: BaseModel

Root configuration object for model training and evaluation.

Source code in packages/contracts/src/contracts/hydra_schemas.py
class TrainingConfig(BaseModel):
    """Root configuration object for model training and evaluation."""

    data_split: DataSplitConfig
    model: ModelConfig

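A Hydra-resolved config arrives as a plain nested dict, and Pydantic coerces ISO date strings into `date` objects during validation. A simplified sketch, assuming Pydantic v2 (the `ModelConfig` branch is omitted here for brevity, so this is not the full `TrainingConfig`):

```python
from datetime import date

from pydantic import BaseModel


# Simplified stand-ins for the documented schemas (nested model config omitted).
class DataSplitConfig(BaseModel):
    """Configuration for temporal data splitting."""

    train_start: date
    train_end: date
    test_start: date
    test_end: date


class TrainingConfig(BaseModel):
    """Root configuration object (reduced to the data split for this sketch)."""

    data_split: DataSplitConfig


raw = {
    "data_split": {
        "train_start": "2023-01-01",
        "train_end": "2024-01-01",
        "test_start": "2024-01-01",
        "test_end": "2024-06-01",
    }
}
cfg = TrainingConfig.model_validate(raw)
print(cfg.data_split.train_start)  # 2023-01-01
```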
contracts.settings

Classes

DataQualitySettings

Bases: BaseSettings

Settings for data quality thresholds in substation flow processing.

These thresholds are used to identify problematic telemetry data:

  • stuck_std_threshold: When the rolling standard deviation falls below this value (across a 48-period/24-hour window), the sensor is likely stuck. We replace such values with null to preserve the temporal grid. A value of 0.01 MW was chosen because substations with normal operation typically have much higher variability.

  • max_mw_threshold: Active power above this value is considered physically unrealistic for primary substations in the NGED portfolio. A threshold of 100.0 MW was chosen because typical primary substations operate in the tens of MW range, and values exceeding 100 MW are extremely rare anomalies.

  • min_mw_threshold: Active power below this value is potentially erroneous (negative values can occur at times of high renewable generation). A threshold of -20.0 MW was chosen to allow for reverse power flow during high renewable generation periods while still catching implausible extreme negative values.

Centralizing these in Settings allows them to be configurable per environment (dev/staging/prod) while preventing logic drift between asset checks and data cleaning steps. All code that references these thresholds should import them from here, not define them locally.

Source code in packages/contracts/src/contracts/settings.py
class DataQualitySettings(BaseSettings):
    """Settings for data quality thresholds in substation flow processing.

    These thresholds are used to identify problematic telemetry data:
    - `stuck_std_threshold`: When the rolling standard deviation falls below this value
      (across a 48-period/24-hour window), the sensor is likely stuck. We replace such
      values with null to preserve the temporal grid. A value of 0.01 MW was chosen
      because substations with normal operation typically have much higher variability.

    - `max_mw_threshold`: Active power above this value is considered physically
      unrealistic for primary substations in the NGED portfolio. A threshold of 100.0 MW
      was chosen because typical primary substations operate in the tens of MW range,
      and values exceeding 100 MW are extremely rare anomalies.

    - `min_mw_threshold`: Active power below this value is potentially erroneous
      (negative values can occur at times of high renewable generation). A threshold of
      -20.0 MW was chosen to allow for reverse power flow during high renewable
      generation periods while still catching implausible extreme negative values.

    Centralizing these in Settings allows them to be configurable per environment
    (dev/staging/prod) while preventing logic drift between asset checks and data
    cleaning steps. All code that references these thresholds should import them from
    here, not define them locally.
    """

    stuck_std_threshold: float = 0.01
    stuck_window_periods: int = 48
    max_mw_threshold: float = 100.0
    min_mw_threshold: float = -20.0

Settings

Bases: BaseSettings

Configuration settings for the NGED substation forecast project.

Source code in packages/contracts/src/contracts/settings.py
class Settings(BaseSettings):
    """Configuration settings for the NGED substation forecast project."""

    # NGED Connected Data CKAN token
    nged_ckan_token: str = Field(...)

    # NWP Data Settings
    nwp_ensemble_member: int = Field(
        default=0,
        description=(
            "Which NWP ensemble member to use for training the ML model (typically 0, the "
            "control member)."
        ),
    )

    # ML Model Settings
    ml_model_ensemble_size: int = Field(
        default=10,
        description=(
            "Number of ML models to train in an ensemble (e.g. using different random seeds) "
            "to improve robustness and provide uncertainty estimates at inference time."
        ),
    )

    # MLflow Tracking URI
    # We centralize the MLflow tracking URI here to allow for environment-specific
    # configuration (e.g., local SQLite for development, remote server for production).
    # The application entrypoint should read this setting and set the MLFLOW_TRACKING_URI
    # environment variable accordingly, which MLflow will automatically pick up.
    mlflow_tracking_uri: str = Field(
        default="sqlite:///mlflow.db",
        description="MLflow tracking URI.",
    )

    # Data Quality Settings
    data_quality: DataQualitySettings = Field(
        default_factory=DataQualitySettings,
        description="Configurable thresholds for data quality checks.",
    )

    # S3 Storage
    nged_s3_bucket_url: str = Field(...)
    nged_s3_bucket_access_key: str = Field(...)
    nged_s3_bucket_secret: str = Field(...)

    # ECMWF Data Settings
    ecmwf_s3_bucket: str = Field(
        default="dynamical-ecmwf-ifs-ens",
        description="S3 bucket for ECMWF Icechunk store.",
    )
    ecmwf_s3_prefix: str = Field(
        default="ecmwf-ifs-ens-forecast-15-day-0-25-degree/v0.1.0.icechunk/",
        description="S3 prefix for ECMWF Icechunk store.",
    )

    # Paths
    nged_data_path: Path = Path("data/NGED")
    nwp_data_path: Path = Path("data/NWP")
    power_forecasts_data_path: Path = Path("data/power_forecasts")
    forecast_metrics_data_path: Path = Path("data/forecast_metrics")
    trained_ml_model_params_base_path: Path = Path("data/trained_ML_model_params")

    model_config = SettingsConfigDict(
        env_file=PROJECT_ROOT / ".env",
        extra="ignore",
        env_file_encoding="utf-8",
        env_prefix="",
    )

    @field_validator("nged_s3_bucket_url")
    @classmethod
    def validate_url(cls, v: str) -> str:
        """Validate that the S3 bucket URL is a valid URL."""
        url_adapter.validate_python(v)
        return v

Functions

validate_url(v) classmethod

Validate that the S3 bucket URL is a valid URL.

Source code in packages/contracts/src/contracts/settings.py
@field_validator("nged_s3_bucket_url")
@classmethod
def validate_url(cls, v: str) -> str:
    """Validate that the S3 bucket URL is a valid URL."""
    url_adapter.validate_python(v)
    return v

Functions

find_project_root()

Find the project root by looking for uv.lock.

Source code in packages/contracts/src/contracts/settings.py
def find_project_root() -> Path:
    """Find the project root by looking for uv.lock."""
    current = Path.cwd()
    for parent in [current, *current.parents]:
        if (parent / "uv.lock").exists():
            return parent
    return current
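
The marker-file search can be demonstrated against a throwaway directory tree (a hypothetical layout; the real marker is the repository's uv.lock, and the function below reimplements the documented logic with an explicit start directory for testability):

```python
import tempfile
from pathlib import Path


def find_project_root_from(start: Path) -> Path:
    """Walk upwards from `start` until a directory containing uv.lock is found."""
    for parent in [start, *start.parents]:
        if (parent / "uv.lock").exists():
            return parent
    return start


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "uv.lock").touch()  # mark the "project root"
    nested = root / "packages" / "contracts" / "src"
    nested.mkdir(parents=True)

    found = find_project_root_from(nested)
    print(found == root)  # True
```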