Compare Data Access Times for AWS and Glade Zarr Stores

[ ]:
import xarray as xr
import numpy as np

import matplotlib.pyplot as plt
from pathlib import Path
import os
import s3fs

Use Dask to Speed up Computations

  • We are testing data transfer rate, so we use many workers to make sure data bandwidth is as saturated as possible.

  • It’s a simple plotting task, so we request little walltime and memory per worker; this helps us get workers from the scheduler more quickly.

[ ]:
# Set up a Dask cluster through ncar_jobqueue and attach a client to it.
import dask
from ncar_jobqueue import NCARCluster

# Resources requested for the cluster.
num_jobs = 30
walltime = "1:00:00"
memory='6GB'
# NOTE(review): `cores` is given the job count (30) while `processes` is 1 --
# confirm that this is the intended per-job core request.
cluster = NCARCluster(cores=num_jobs, processes=1, memory=memory, walltime=walltime)
# Ask the cluster to scale to `num_jobs` jobs.
cluster.scale(jobs=num_jobs)

from distributed import Client
from distributed.utils import format_bytes
# Connect a client to the cluster.
# NOTE(review): presumably this is what routes the later xarray/dask
# computations to these workers -- confirm.
client = Client(cluster)
# Bare expression on the last line of the cell: the notebook displays the cluster object.
cluster

Functions for Getting Data

[ ]:
def get_aws_store(store_name):
    """Open the named Zarr store from the NA-CORDEX bucket on Amazon S3.

    The store is looked up at ``s3://ncar-na-cordex/day/<store_name>``,
    accessed anonymously, and opened with consolidated metadata.

    Parameters
    ----------
    store_name : str
        Name of the Zarr store under the bucket's ``day`` prefix.

    Returns
    -------
    The dataset object produced by ``xr.open_zarr`` for that store.
    """
    bucket = "s3://ncar-na-cordex"
    frequency = "day"
    store_url = f"{bucket}/{frequency}/{store_name}"

    # anon=True: no AWS credentials are supplied for the read.
    anonymous_fs = s3fs.S3FileSystem(anon=True)
    mapper = s3fs.S3Map(root=store_url, s3=anonymous_fs)

    return xr.open_zarr(mapper, consolidated=True)
[ ]:
def get_glade_store(store_name):
    """Open the named Zarr store from the NCAR Glade file system.

    The store is looked up under
    ``/glade/scratch/bonnland/na-cordex/zarr-publish`` and opened with
    consolidated metadata.

    Parameters
    ----------
    store_name : str
        Name of the Zarr store directory under the Glade root.

    Returns
    -------
    The dataset object produced by ``xr.open_zarr`` for that store.
    """
    glade_root = "/glade/scratch/bonnland/na-cordex/zarr-publish"
    return xr.open_zarr(f"{glade_root}/{store_name}", consolidated=True)
[ ]:

Define Plot Functions

Create Single Map Plot (Helper Function)

[ ]:
def plotMap(ax, map_slice, date_object=None, member_id=None):
    '''Draw one map panel on the given axes and annotate it with its min/max.

    The slice is shown with ``imshow`` (origin at the lower left), its minimum
    and maximum over the ``lat``/``lon`` dimensions are written in the bottom
    corners, and the axis ticks are removed.

    Parameters
    ----------
    ax : axes object to draw on.
    map_slice : 2-D data with ``lat`` and ``lon`` dimensions; must support
        ``.min(dim=...)`` and ``.max(dim=...)``.
    date_object : optional; when given, the first 10 characters of
        ``date_object.values.astype(str)`` become the panel title.
    member_id : optional; when given, used as the y-axis label.

    Returns
    -------
    The same ``ax`` that was passed in.
    '''
    ax.imshow(map_slice, origin='lower')

    minval = map_slice.min(dim=['lat', 'lon'])
    maxval = map_slice.max(dim=['lat', 'lon'])

    # "%4g" pads to a minimum field width of 4 characters; the number of
    # significant digits is the %g default, not 4.
    ax.text(0.01, 0.03, "%4g" % minval, transform=ax.transAxes, fontsize=12)
    ax.text(0.99, 0.03, "%4g" % maxval, transform=ax.transAxes, fontsize=12, horizontalalignment='right')
    ax.set_xticks([])
    ax.set_yticks([])

    # Compare against None explicitly: a supplied value that happens to be
    # falsy must still be shown, and truth-testing array-like objects is fragile.
    if date_object is not None:
        ax.set_title(date_object.values.astype(str)[:10], fontsize=12)

    if member_id is not None:
        ax.set_ylabel(member_id, fontsize=12)

    return ax

Function Producing Maps of First, Middle, Last Timesteps

[ ]:
def getValidDateIndexes(member_slice):
    '''Return the positions of the first and last steps with a finite spatial minimum.

    The minimum over the ``lat``/``lon`` dimensions is taken for every
    remaining step, and the first and last positions where that minimum is
    finite (neither NaN nor infinite) are returned as ``(start, end)``.

    Raises IndexError when no step has a finite minimum.
    '''
    spatial_min = member_slice.min(dim=['lat', 'lon'])

    # With a single argument np.where returns one index array per dimension;
    # only the first dimension's indexes are used.
    valid_positions = np.where(np.isfinite(spatial_min))[0]

    return valid_positions[0], valid_positions[-1]


def plot_first_mid_last(ds, data_var, store_name, plotdir):
    """Show maps of the first valid, middle, and last valid time steps.

    One row of three maps is drawn for each of the first four ensemble
    members. The first and last columns use the first and last time steps
    with finite data for that member; the middle column uses the midpoint of
    the whole time axis. The figure is titled with ``store_name`` and shown.

    With 30 workers, expect 1 minute walltime for computation and 1-2 minutes
    for plot rendering on Glade.

    Parameters
    ----------
    ds : dataset providing ``coords['member_id']``, ``time`` and
        ``ds[data_var]`` (presumably an xarray Dataset -- confirm with caller).
    data_var : name of the variable in ``ds`` to plot.
    store_name : text used as the figure title.
    plotdir : not used by this function; kept so existing calls still work.
    """
    member_names = ds.coords['member_id'].values[0:4]

    numPlotsPerPage = 4
    numPlotCols = 3

    figWidth = 18
    figHeight = 12

    fig, axs = plt.subplots(numPlotsPerPage, numPlotCols, figsize=(figWidth, figHeight), constrained_layout=True)

    # The title and the middle time step are the same for every member,
    # so they are set up once rather than on each loop iteration.
    fig.suptitle(store_name, fontsize=20)
    midDateIndex = len(ds.time) // 2
    midDate = ds.time[midDateIndex]

    # zip stops at the shorter input, so rows beyond the available members stay empty.
    for row_axes, mem_id in zip(axs, member_names):
        data_slice = ds[data_var].sel(member_id=mem_id)

        start_index, end_index = getValidDateIndexes(data_slice)
        startDate = ds.time[start_index]
        endDate = ds.time[end_index]

        # Only the first column carries the member label.
        plotMap(row_axes[0], data_slice.sel(time=startDate), startDate, mem_id)
        plotMap(row_axes[1], data_slice.sel(time=midDate), midDate)
        plotMap(row_axes[2], data_slice.sel(time=endDate), endDate)

    plt.show()

Create Time Series Plots over Multiple Pages

These also mark the locations of missing values.

[ ]:
def plot_timeseries(ds, data_var, store_name, plotdir):
    """Show time series of spatial statistics for the first four ensemble members.

    For each member, the min, max, mean and standard deviation over the
    ``lat``/``lon`` dimensions are plotted against ``ds.time`` in that
    member's own panel. Time steps whose spatial minimum is NaN are marked
    with magenta ticks near the bottom of the panel.

    With 30 workers, expect 1 minute walltime for computation and 1-2 minutes
    for plot rendering on Glade.

    Parameters
    ----------
    ds : dataset providing ``coords['member_id']``, ``time`` and
        ``ds[data_var]`` (presumably an xarray Dataset -- confirm with caller).
    data_var : name of the variable in ``ds`` to plot.
    store_name : text used as the figure title.
    plotdir : not used by this function; kept so existing calls still work.
    """
    member_names = ds.coords['member_id'].values[0:4]

    numPages = 1
    numPlotsPerPage = 4
    numPlotCols = 1

    figWidth = 25
    figHeight = 20

    linewidth = 0.5

    # The units label is the same for every member, so look it up once.
    # A variable without a 'units' attribute gets a blank label instead of a KeyError.
    unit_string = ds[data_var].attrs.get('units', '')

    # NOTE(review): every page would plot the same four members; the page loop
    # only matters if numPages is raised and paging logic is added.
    for _ in range(numPages):

        # Plot the aggregate statistics across time.
        fig, axs = plt.subplots(numPlotsPerPage, numPlotCols, figsize=(figWidth, figHeight))

        # zip stops at the shorter input, so panels beyond the available members stay empty.
        for ax, mem_id in zip(axs, member_names):
            data_slice = ds[data_var].sel(member_id=mem_id)

            # Insertion order fixes the order (and so the colours) of the plotted lines.
            stats = {
                'min': data_slice.min(dim=['lat', 'lon']),
                'max': data_slice.max(dim=['lat', 'lon']),
                'mean': data_slice.mean(dim=['lat', 'lon']),
                'std': data_slice.std(dim=['lat', 'lon']),
            }

            nan_times = ds.time[np.isnan(stats['min'])]

            for label, series in stats.items():
                ax.plot(ds.time, series, linewidth=linewidth, label=label)

            # Place the missing-value marks 1% above the bottom of the current y-range.
            ymin, ymax = ax.get_ylim()
            rug_y = ymin + 0.01*(ymax-ymin)
            ax.plot(nan_times, [rug_y]*len(nan_times), '|', color='m', label='isnan')
            ax.set_title(mem_id, fontsize=20)
            ax.legend(loc='upper right')
            ax.set_ylabel(unit_string)

        fig.suptitle(store_name, fontsize=25)
        fig.tight_layout(pad=10.2, w_pad=3.5, h_pad=3.5)

    plt.show()

Load Zarr Store from Different Sources and Record Plotting Times

[ ]:
# Name of the Zarr store to open from each source in the cells below.
store_name = 'prec.rcp85.day.NAM-44i.raw.zarr'
# Passed to the plot functions as `plotdir`; the plot functions defined in this notebook do not read it.
plotdir = '.'
# Name of the variable in the store to plot.
data_var = 'prec'
[ ]:

Plot First, Middle, and Final Timesteps from Glade Store

[ ]:
%%time

# Plot using a Zarr Store on Glade
# The cell's %%time magic times both steps together: opening the store and drawing the maps.
ds = get_glade_store(store_name)
plot_first_mid_last(ds, data_var, store_name, plotdir)

Plot First, Middle, and Final Timesteps from AWS Store

[ ]:
%%time

# Plot using a Zarr Store on AWS
# Same calls as the Glade cell, so the two %%time results can be compared; `ds` is rebound to the AWS store.
ds = get_aws_store(store_name)
plot_first_mid_last(ds, data_var, store_name, plotdir)
[ ]:

[ ]:

Plot Time Series from Glade Store

[ ]:
%%time

# Plot using a Zarr Store on Glade
# The cell's %%time magic times both steps together: opening the store and drawing the time series.
ds = get_glade_store(store_name)
plot_timeseries(ds, data_var, store_name, plotdir)

Plot Time Series from AWS Store

[ ]:
%%time

# Plot using a Zarr Store on AWS
# Same calls as the Glade cell, so the two %%time results can be compared; `ds` is rebound to the AWS store.
ds = get_aws_store(store_name)
plot_timeseries(ds, data_var, store_name, plotdir)
[ ]:

Release the workers.

[ ]:
!date
[ ]:
# Close the client first so it disconnects cleanly before the cluster
# (and the worker jobs it scaled up) is shut down.
client.close()
cluster.close()
[ ]: