Rollout Issues

#1
by motte12 - opened

Hey Prithvi-WxC developers and community,

First off, thanks for sharing this interesting model! I’ve been using it for a university project, and it works pretty well overall, but I’ve run into a bit of a problem. I’ve been running a few rollouts, but I’m seeing what I think are fairly blurred or spatially too homogeneous forecasts for a lot of the variables.

For example, I’ve been trying to reproduce some of the storm tracks from the paper, but the results aren’t showing clear hurricanes at all. The forecasts just seem too uniform across the board. I’ve attached a figure for reference (no particular storm, just a random date)—if you look at the rollout (left subplot), the SLP field seems much too homogeneous to match the storm patterns.

SLP_rollout_comparison.png

Any idea what might be causing this? Could I be missing some important settings or parameters? It’s quite possible I’ve overlooked something obvious. Note that I was able to run the example notebook (PrithviWxC_rollout.ipynb) without issues and reproduce the plot at the bottom of the notebook, so I suspect the problem is specific to my setup.

I used the script attached at the bottom of this post to run the rollout shown in the figure, with the following parameters:

  • time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")
  • lead_time = 6
  • input_time = -6
  • masking_ratio = 0.99
  • positional_encoding = "fourier"

On a side note, rollouts with a higher masking ratio seem to work better for me. Could it be that (1 − masking ratio) is the fraction of grid cells that is actually masked, or am I misunderstanding something here?

In summary I have 3 questions I would appreciate your comments on:

  1. The issue with the blurred/spatially too homogeneous rollouts
  2. Clarification of the masking ratio
  3. The model ibm-nasa-geospatial/Prithvi-WxC-1.0-2300M-rollout should be able to forecast lead times longer than 6 hours when the input time / time delta is 6 hours, right?

I would appreciate your feedback very much, so thank you in advance! Let me know if I should provide any more info.

Best
Frieder

The script:

import time
import pandas as pd
import torch
import pickle
import yaml
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from huggingface_hub import hf_hub_download, snapshot_download
from PrithviWxC.dataloaders.merra2_rollout import Merra2RolloutDataset
from PrithviWxC.dataloaders.merra2 import (
    input_scalers,
    output_scalers,
    static_input_scalers,
)
from PrithviWxC.model import PrithviWxC
from PrithviWxC.dataloaders.merra2_rollout import preproc
from PrithviWxC.rollout import rollout_iter

def setup_device_and_random_seeds():
    """
    Select the compute device and seed all random number generators.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (plus
    CUDA when available) with the fixed value 42, and sets cuDNN flags when a
    CUDA device is present.

    Returns:
        torch.device: ``cuda`` if a CUDA device is available, otherwise ``cpu``.
    """
    # ``random`` is not imported at module level in this script, so the
    # original ``random.seed`` call raised NameError; import it locally.
    import random

    seed = 42

    # Enable oneDNN fusion in the TorchScript JIT (intended to help Intel CPUs).
    torch.jit.enable_onednn_fusion(True)

    if torch.cuda.is_available():
        # NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice,
        # which can pick non-deterministic kernels and partly counteracts
        # deterministic=True. Set benchmark=False if bit-exact reruns matter.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

    # Seed every RNG the pipeline may touch so reruns are reproducible.
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("The device is:", device)
    return device


def load_config(my_dir):
    """
    Load the model configuration from ``<my_dir>/config.yaml``.

    Args:
        my_dir: Directory (str or path-like) containing ``config.yaml``.

    Returns:
        The object produced by ``yaml.safe_load``. Callers index it as
        ``config['params'][...]``, so a mapping is expected.
    """
    config_path = f"{my_dir}/config.yaml"
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def create_dataset(my_dir, time_range, lead_time, input_time, masking_ratio, positional_encoding, surface_vars, static_surface_vars, vertical_vars, levels):
    """
    Create the MERRA-2 rollout dataset by initializing ``Merra2RolloutDataset``.

    Args:
        my_dir: Base data directory. Surface and vertical MERRA-2 files are both
            read from ``<my_dir>/merra-2``; surface and vertical climatology
            files are both read from ``<my_dir>/climatology``.
        time_range: ``(start, end)`` timestamp strings passed to the dataset.
        lead_time: Passed to the dataset as ``lead_time``.
        input_time: Passed to the dataset as ``input_time``.
        masking_ratio: Not used by this function; kept only so existing calls
            keep working. Masking is configured on the model
            (``mask_ratio_inputs``), not on the dataset.
        positional_encoding: Passed to the dataset as ``positional_encoding``.
        surface_vars, static_surface_vars, vertical_vars, levels: Variable
            and level selections passed through unchanged.

    Returns:
        The constructed ``Merra2RolloutDataset``.

    Raises:
        ValueError: If the dataset is empty.
    """
    # Surface and vertical inputs (and their climatologies) share a directory.
    merra_dir = Path(f"{my_dir}/merra-2")
    clim_dir = Path(f"{my_dir}/climatology")

    dataset = Merra2RolloutDataset(
        time_range=time_range,
        lead_time=lead_time,
        input_time=input_time,
        data_path_surface=merra_dir,
        data_path_vertical=merra_dir,
        climatology_path_surface=clim_dir,
        climatology_path_vertical=clim_dir,
        surface_vars=surface_vars,
        static_surface_vars=static_surface_vars,
        vertical_vars=vertical_vars,
        levels=levels,
        positional_encoding=positional_encoding,
    )

    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if len(dataset) == 0:
        raise ValueError("There doesn't seem to be any valid data.")
    return dataset


def setup_scalers(my_dir, surface_vars, vertical_vars, levels, static_surface_vars):
    """
    Build the normalization scalers used by the model.

    All statistics are read from NetCDF files under ``<my_dir>/climatology``.

    Returns:
        tuple: ``(in_mu, in_sig, output_sig, static_mu, static_sig)``. The first
        two come from ``input_scalers``, ``output_sig`` comes from
        ``output_scalers``, and the last two come from ``static_input_scalers``.
    """
    clim_dir = f"{my_dir}/climatology"
    musigma_surface = Path(f"{clim_dir}/musigma_surface.nc")
    musigma_vertical = Path(f"{clim_dir}/musigma_vertical.nc")
    anomaly_var_surface = Path(f"{clim_dir}/anomaly_variance_surface.nc")
    anomaly_var_vertical = Path(f"{clim_dir}/anomaly_variance_vertical.nc")

    mu_in, sigma_in = input_scalers(
        surface_vars, vertical_vars, levels, musigma_surface, musigma_vertical
    )
    sigma_out = output_scalers(
        surface_vars, vertical_vars, levels, anomaly_var_surface, anomaly_var_vertical
    )
    # The static inputs read their statistics from the same surface mu/sigma file.
    mu_static, sigma_static = static_input_scalers(musigma_surface, static_surface_vars)

    return mu_in, sigma_in, sigma_out, mu_static, sigma_static


def load_model(my_dir, config, in_mu, in_sig, static_mu, static_sig, output_sig, positional_encoding, masking_ratio):
    """
    Build a ``PrithviWxC`` model from ``config`` and load the rollout weights.

    Args:
        my_dir: Base directory. Weights are read from
            ``<my_dir>/weights/prithvi.wxc.rollout.2300m.v1.pt``.
        config: Parsed configuration; ``config['params']`` holds the
            architecture hyper-parameters.
        in_mu, in_sig: Input normalization mean / sigma.
        static_mu, static_sig: Static-input normalization mean / sigma.
        output_sig: Output scaler values. They are square-rooted before being
            passed to the model (presumably variance -> standard deviation).
        positional_encoding: Positional encoding type passed to the model.
        masking_ratio: Passed to the model as ``mask_ratio_inputs``.

    Returns:
        The model with the checkpoint loaded (strict key matching). Its
        parameters are on CPU; the caller is expected to move it to a device.
    """
    params = config['params']
    model = PrithviWxC(
        in_channels=params['in_channels'],
        input_size_time=params['input_size_time'],
        in_channels_static=params['in_channels_static'],
        input_scalers_mu=in_mu,
        input_scalers_sigma=in_sig,
        input_scalers_epsilon=params['input_scalers_epsilon'],
        static_input_scalers_mu=static_mu,
        static_input_scalers_sigma=static_sig,
        static_input_scalers_epsilon=params['static_input_scalers_epsilon'],
        output_scalers=output_sig**0.5,
        n_lats_px=params['n_lats_px'],
        n_lons_px=params['n_lons_px'],
        patch_size_px=params['patch_size_px'],
        mask_unit_size_px=params['mask_unit_size_px'],
        mask_ratio_inputs=masking_ratio,
        embed_dim=params['embed_dim'],
        n_blocks_encoder=params['n_blocks_encoder'],
        n_blocks_decoder=params['n_blocks_decoder'],
        mlp_multiplier=params['mlp_multiplier'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        drop_path=params['drop_path'],
        parameter_dropout=params['parameter_dropout'],
        residual="climate",
        masking_mode="local",
        decoder_shifting=True,
        positional_encoding=positional_encoding,
        checkpoint_encoder=[],
        checkpoint_decoder=[],
    )

    weights_path = Path(f"{my_dir}/weights/prithvi.wxc.rollout.2300m.v1.pt")
    # map_location="cpu": a checkpoint saved from GPU would otherwise fail to
    # load on CPU-only hosts. The model itself is built on CPU anyway.
    # NOTE(review): weights_only=False unpickles arbitrary objects and can run
    # code embedded in the file; only load checkpoints from a trusted source.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
    # Some checkpoints wrap the weights under a "model_state" key.
    if "model_state" in state_dict:
        state_dict = state_dict["model_state"]
    model.load_state_dict(state_dict, strict=True)

    return model


def save_data_to_disk(batch, y_targets, target_times, out, all_outputs_cpu, my_dir, lead_time, input_time, masking_ratio, time_range, output_dir=None):
    """
    Pickle the rollout targets, target times and per-step model outputs.

    Three files are written to ``output_dir``. They share the suffix
    ``leadtime{lead_time}_input_time{input_time}_maskingrat{masking_ratio}_{start}-{end}.pkl``,
    where ``start``/``end`` are the first 10 characters of ``time_range[0]``
    and ``time_range[1]``:

    * ``simple_rollout_y_targets_test_<suffix>`` holds ``y_targets``
    * ``simple_rollout_y_target_times_list_test_<suffix>`` holds ``target_times``
    * ``simple_rollout_<suffix>`` holds ``all_outputs_cpu``

    Args:
        batch, out, my_dir: Not used; kept so existing calls keep working.
            Note that neither ``batch`` nor the final output ``out`` is saved.
        y_targets: Rollout targets to pickle.
        target_times: Target timestamps to pickle.
        all_outputs_cpu: Per-step model outputs to pickle.
        lead_time, input_time, masking_ratio, time_range: Used only to build
            the file names.
        output_dir: Destination directory. Defaults to the original hard-coded
            path for backward compatibility. Created if it does not exist.
    """
    if output_dir is None:
        output_dir = "/work/fl53wumy-scads_intern_restored/fl53wumy-internship_scads-1734574817/prithvi_output/output_rollout/"
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    suffix = (
        f"leadtime{lead_time}_input_time{input_time}_maskingrat{masking_ratio}"
        f"_{time_range[0][0:10]}-{time_range[1][0:10]}.pkl"
    )
    to_save = (
        (f"simple_rollout_y_targets_test_{suffix}", y_targets),
        (f"simple_rollout_y_target_times_list_test_{suffix}", target_times),
        (f"simple_rollout_{suffix}", all_outputs_cpu),
    )
    for filename, obj in to_save:
        with open(target_dir / filename, "wb") as pickle_file:
            pickle.dump(obj, pickle_file)


def main():
    """
    Run one Prithvi-WxC rollout over a fixed time window and pickle the results.

    Seeds the RNGs and picks the device, then loads the config, dataset,
    scalers and model. It preprocesses the first dataset sample, runs
    ``rollout_iter`` for ``dataset.nsteps`` steps without gradients, and saves
    the targets and per-step outputs.
    """
    # perf_counter is monotonic and intended for measuring durations;
    # time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()

    # Base directory and time range for the dataset
    my_dir = "/work/fl53wumy-scads_intern_restored/fl53wumy-internship_scads-1734574817/data_roullout"
    time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")

    # Experiment parameters
    lead_time = 6
    input_time = -6
    # NOTE(review): masking_ratio is forwarded to the model as mask_ratio_inputs.
    # Confirm its meaning in PrithviWxC.model. If it is the fraction of input
    # mask units hidden from the encoder, 0.99 leaves only ~1% of the input
    # visible, which could explain very smooth rollouts.
    masking_ratio = 0.99
    positional_encoding = "fourier"
    # Padding passed to preproc, per dimension (presumably [before, after];
    # confirm against preproc).
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}

    # Surface, static and vertical variables plus model levels
    surface_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV",
        "TS", "U10M", "V10M", "Z0M"
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    levels = [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0]

    device = setup_device_and_random_seeds()
    config = load_config(my_dir)
    dataset = create_dataset(my_dir, time_range, lead_time, input_time, masking_ratio, positional_encoding, surface_vars, static_surface_vars, vertical_vars, levels)
    in_mu, in_sig, output_sig, static_mu, static_sig = setup_scalers(my_dir, surface_vars, vertical_vars, levels, static_surface_vars)

    # Module.to() leaves parameters already on the target device untouched,
    # so no device check is needed before calling it.
    model = load_model(my_dir, config, in_mu, in_sig, static_mu, static_sig, output_sig, positional_encoding, masking_ratio)
    model = model.to(device)

    # Preprocess the first sample into a batch of size one.
    data, target_times = next(iter(dataset))
    batch = preproc([data], padding)
    # Keep the targets as produced by preproc, before the batch is moved to the device.
    y_targets = batch['ys']

    # Move every tensor in the batch to the selected device (GPU or CPU).
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)

    # Inference only: evaluation mode and no autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        out, _batch_returned, all_outputs = rollout_iter(dataset.nsteps, model, batch)

    all_outputs_cpu = [step_output.cpu() for step_output in all_outputs]

    save_data_to_disk(batch, y_targets, target_times, out, all_outputs_cpu, my_dir, lead_time, input_time, masking_ratio, time_range)

    elapsed_time = time.perf_counter() - start_time
    print(f"Execution time: {elapsed_time:.4f} seconds")


if __name__ == "__main__":
    # Run the rollout pipeline only when this file is executed as a script,
    # not when it is imported as a module.
    main()

Sign up or log in to comment