Rollout Issues
Hey Prithvi-WxC developers and community,
First off, thanks for sharing this interesting model! I've been using it for a university project and it works well overall, but I've run into a problem: when I run rollouts, the forecasts for many of the variables look blurred or spatially far too homogeneous.
For example, I've been trying to reproduce some of the storm tracks from the paper, but the results don't show clear hurricanes at all; the forecasts simply look too uniform. I've attached a figure for reference (no particular storm, just a random date): in the rollout (left subplot), the SLP field seems much too homogeneous to match the storm patterns.
Any idea what might be causing this? Could I be missing some important settings or parameters? It may well be that I've overlooked something obvious. Note that I was able to run the example notebook (PrithviWxC_rollout.ipynb) without issues and reproduced the plot at the bottom of the notebook, so I'm guessing it's something specific to my setup.
I use the script attached at the bottom for the rollout/figure, with the following parameters:
- time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")
- lead_time = 6
- input_time = -6
- masking_ratio = 0.99
- positional_encoding = "fourier"
On a side note, rollouts with a higher masking ratio seem to work better for me. Could it be that (1 - masking_ratio) is the fraction of grid cells that is actually masked, or am I getting something wrong here?
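For context, my mental model is the usual MAE-style convention, in which the ratio is the fraction of units hidden from the encoder rather than the fraction kept. A minimal sketch of that convention (the names here are illustrative, not the actual PrithviWxC internals):

import math

def n_masked_units(n_units: int, mask_ratio: float) -> int:
    # MAE-style convention: mask_ratio is the fraction of units HIDDEN,
    # so 0.99 hides 99% of the mask units and keeps only 1% visible
    return math.floor(mask_ratio * n_units)

assert n_masked_units(100, 0.99) == 99  # 99 hidden, only 1 visible

If PrithviWxC follows this convention, masking_ratio = 0.99 would leave the encoder only 1% of the input, so I want to double-check that I'm reading it correctly.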
In summary, I have three questions I would appreciate your comments on:
- The issue with the blurred/spatially too homogeneous rollouts
- Clarification of the masking ratio
- The model ibm-nasa-geospatial/Prithvi-WxC-1.0-2300M-rollout should be able to forecast lead times of more than 6 hrs with an input time/time delta of 6 hrs, right? (See the sketch just below this list.)
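Regarding the last point, here is my (possibly wrong) understanding of how a longer lead time would be split into autoregressive steps; the numbers are made up for illustration:

# Hypothetical sketch of how I assume lead_time maps to rollout steps:
# with a base time delta of |input_time| = 6 h, a 24 h forecast would be
# four autoregressive 6 h steps
lead_time = 24    # target forecast horizon in hours
input_time = -6   # time delta between the two input states in hours
nsteps = lead_time // abs(input_time)
print(nsteps)     # -> 4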
I'd very much appreciate your feedback; thank you in advance! Let me know if I should provide any more info.
Best
Frieder
The script:
import time
import pandas as pd
import random  # needed for random.seed below
import torch
import pickle
import yaml
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download, snapshot_download
from PrithviWxC.dataloaders.merra2_rollout import Merra2RolloutDataset
from PrithviWxC.dataloaders.merra2 import (
input_scalers,
output_scalers,
static_input_scalers,
)
from PrithviWxC.model import PrithviWxC
from PrithviWxC.dataloaders.merra2_rollout import preproc
from PrithviWxC.rollout import rollout_iter
def setup_device_and_random_seeds():
"""
Set the device (CUDA or CPU) and initialize random seeds for reproducibility.
"""
    # Enable oneDNN fusion in TorchScript for better performance on Intel CPUs
    torch.jit.enable_onednn_fusion(True)
    # cuDNN settings: benchmark=True auto-tunes kernels but is non-deterministic,
    # so it is disabled here in favour of reproducibility
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
# Set random seeds for reproducibility across various libraries
random.seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)
torch.manual_seed(42)
np.random.seed(42)
# Return the device (either CUDA or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("The device is:", device)
return device
def load_config(my_dir):
"""
Load configuration parameters from a YAML file for model setup.
"""
with open(f'{my_dir}/config.yaml', 'r') as f:
config = yaml.safe_load(f)
return config
def create_dataset(my_dir, time_range, lead_time, input_time, masking_ratio, positional_encoding, surface_vars, static_surface_vars, vertical_vars, levels):
"""
Create the dataset by initializing the Merra2RolloutDataset class.
"""
# Set directory paths for the data
surf_dir = Path(f"{my_dir}/merra-2")
vert_dir = Path(f"{my_dir}/merra-2")
surf_clim_dir = Path(f"{my_dir}/climatology")
vert_clim_dir = Path(f"{my_dir}/climatology")
# Create dataset
dataset = Merra2RolloutDataset(
time_range=time_range,
lead_time=lead_time,
input_time=input_time,
data_path_surface=surf_dir,
data_path_vertical=vert_dir,
climatology_path_surface=surf_clim_dir,
climatology_path_vertical=vert_clim_dir,
surface_vars=surface_vars,
static_surface_vars=static_surface_vars,
vertical_vars=vertical_vars,
levels=levels,
positional_encoding=positional_encoding,
)
assert len(dataset) > 0, "There doesn't seem to be any valid data."
return dataset
def setup_scalers(my_dir, surface_vars, vertical_vars, levels, static_surface_vars):
"""
Load the necessary scalers for input/output normalization.
"""
# Set file paths for scalers
surf_in_scal_path = Path(f"{my_dir}/climatology/musigma_surface.nc")
vert_in_scal_path = Path(f"{my_dir}/climatology/musigma_vertical.nc")
surf_out_scal_path = Path(f"{my_dir}/climatology/anomaly_variance_surface.nc")
vert_out_scal_path = Path(f"{my_dir}/climatology/anomaly_variance_vertical.nc")
# Load input and output scalers
in_mu, in_sig = input_scalers(surface_vars, vertical_vars, levels, surf_in_scal_path, vert_in_scal_path)
output_sig = output_scalers(surface_vars, vertical_vars, levels, surf_out_scal_path, vert_out_scal_path)
static_mu, static_sig = static_input_scalers(surf_in_scal_path, static_surface_vars)
return in_mu, in_sig, output_sig, static_mu, static_sig
def load_model(my_dir, config, in_mu, in_sig, static_mu, static_sig, output_sig, positional_encoding, masking_ratio):
"""
Initialize and load the model.
"""
# Load model configuration and initialize the model
model = PrithviWxC(
in_channels=config['params']['in_channels'],
input_size_time=config['params']['input_size_time'],
in_channels_static=config['params']['in_channels_static'],
input_scalers_mu=in_mu,
input_scalers_sigma=in_sig,
input_scalers_epsilon=config['params']['input_scalers_epsilon'],
static_input_scalers_mu=static_mu,
static_input_scalers_sigma=static_sig,
static_input_scalers_epsilon=config['params']['static_input_scalers_epsilon'],
output_scalers=output_sig**0.5,
n_lats_px=config['params']['n_lats_px'],
n_lons_px=config['params']['n_lons_px'],
patch_size_px=config['params']['patch_size_px'],
mask_unit_size_px=config['params']['mask_unit_size_px'],
mask_ratio_inputs=masking_ratio,
embed_dim=config['params']['embed_dim'],
n_blocks_encoder=config['params']['n_blocks_encoder'],
n_blocks_decoder=config['params']['n_blocks_decoder'],
mlp_multiplier=config['params']['mlp_multiplier'],
n_heads=config['params']['n_heads'],
dropout=config['params']['dropout'],
drop_path=config['params']['drop_path'],
parameter_dropout=config['params']['parameter_dropout'],
residual="climate",
masking_mode="local",
decoder_shifting=True,
positional_encoding=positional_encoding,
checkpoint_encoder=[],
checkpoint_decoder=[],
)
# Load the model's pretrained weights
weights_path = Path(f"{my_dir}/weights/prithvi.wxc.rollout.2300m.v1.pt")
state_dict = torch.load(weights_path, weights_only=False)
if "model_state" in state_dict:
state_dict = state_dict["model_state"]
model.load_state_dict(state_dict, strict=True)
return model
def save_data_to_disk(batch, y_targets, target_times, out, all_outputs_cpu, my_dir, lead_time, input_time, masking_ratio, time_range):
"""
    Save the rollout targets, target times, and model outputs to disk as pickle files.
"""
    # Define the output directory for saved results
    output_dir = '/work/fl53wumy-scads_intern_restored/fl53wumy-internship_scads-1734574817/prithvi_output/output_rollout/'
# Save rollout targets (truth values)
with open(f'{output_dir}simple_rollout_y_targets_test_leadtime{lead_time}_input_time{input_time}_maskingrat{masking_ratio}_{time_range[0][0:10]}-{time_range[1][0:10]}.pkl', 'wb') as pickle_file:
pickle.dump(y_targets, pickle_file)
# Save rollout target times
with open(f'{output_dir}simple_rollout_y_target_times_list_test_leadtime{lead_time}_input_time{input_time}_maskingrat{masking_ratio}_{time_range[0][0:10]}-{time_range[1][0:10]}.pkl', 'wb') as pickle_file:
pickle.dump(target_times, pickle_file)
# Save all outputs (forecast with intermediate steps)
with open(f'{output_dir}simple_rollout_leadtime{lead_time}_input_time{input_time}_maskingrat{masking_ratio}_{time_range[0][0:10]}-{time_range[1][0:10]}.pkl', 'wb') as pickle_file:
pickle.dump(all_outputs_cpu, pickle_file)
def main():
start_time = time.time()
# Define the base directory and time range for the dataset
my_dir = "/work/fl53wumy-scads_intern_restored/fl53wumy-internship_scads-1734574817/data_roullout"
time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")
# Experiment parameters
lead_time = 6
input_time = -6
masking_ratio = 0.99
positional_encoding = "fourier"
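    # Pad/crop to the model's pixel grid: the lat entry [0, -1] drops one latitude
    # row so the 361-row MERRA-2 grid matches the model's 360 rows (as in the notebook)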
padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
# Define surface and vertical variables
surface_vars = [
"EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
"PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV",
"TS", "U10M", "V10M", "Z0M"
]
static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
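    # The 14 MERRA-2 model levels used here match the example notebook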
levels = [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0]
# Step 1: Set up device (CPU or GPU) and initialize random seeds
device = setup_device_and_random_seeds()
# Step 2: Load configuration settings from YAML file
config = load_config(my_dir)
# Step 3: Create the dataset
dataset = create_dataset(my_dir, time_range, lead_time, input_time, masking_ratio, positional_encoding, surface_vars, static_surface_vars, vertical_vars, levels)
# Step 4: Set up input/output scalers for data normalization
in_mu, in_sig, output_sig, static_mu, static_sig = setup_scalers(my_dir, surface_vars, vertical_vars, levels, static_surface_vars)
# Step 5: Load the model and its weights, move to device
model = load_model(my_dir, config, in_mu, in_sig, static_mu, static_sig, output_sig, positional_encoding, masking_ratio)
    # .to(device) is a no-op if the model is already on the target device
    model = model.to(device)
# Step 6: Prepare data and move it to GPU if available
data, target_times = next(iter(dataset))
    batch = preproc([data], padding)  # collate the sample and apply the padding defined above
y_targets = batch['ys']
# Transfer the batch to GPU
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(device)
# Step 7: Run the model and generate predictions (rollout)
    model.eval()
    with torch.no_grad():
        out, batch_returned, all_outputs = rollout_iter(dataset.nsteps, model, batch)
# Step 8: Transfer all outputs back to CPU
    all_outputs_cpu = [step_out.cpu() for step_out in all_outputs]
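    # Quick diagnostic for the homogeneity issue (assuming rollout_iter returns
    # the final prediction as a tensor): a near-zero spatial std would mean the
    # forecast has collapsed to an almost constant field
    print(f"Final output stats: min={out.min().item():.3f}, "
          f"max={out.max().item():.3f}, std={out.std().item():.3f}")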
# Step 9: Save the results to disk
save_data_to_disk(batch, y_targets, target_times, out, all_outputs_cpu, my_dir, lead_time, input_time, masking_ratio, time_range)
# Record the end time and calculate elapsed time
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Execution time: {elapsed_time:.4f} seconds")
if __name__ == "__main__":
main()