Compare Data Access Times for AWS and Glade Zarr Stores
[ ]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
import s3fs
Use Dask to Speed Up Computations
We are testing data transfer rates, so we use many workers to saturate the available bandwidth as fully as possible.
Because this is a simple plotting task, we request little walltime and memory per worker so that the scheduler can provide workers more quickly.
[ ]:
import dask
from distributed import Client
from distributed.utils import format_bytes
from ncar_jobqueue import NCARCluster

# Dask cluster resource request. With processes=1 each batch job hosts a
# single worker, so `num_jobs` is also the number of workers.
num_jobs = 30
# Cores requested per job. The original passed `num_jobs` here, which asked for
# 30 cores in every one of the 30 jobs -- the opposite of the small per-worker
# request described above. Raise this if more threads per worker are wanted.
cores_per_job = 1
walltime = "1:00:00"
memory = '6GB'

cluster = NCARCluster(cores=cores_per_job, processes=1, memory=memory, walltime=walltime)
cluster.scale(jobs=num_jobs)
client = Client(cluster)

# Last expression in the cell: display the cluster summary in the notebook.
cluster
Functions for Getting Data
[ ]:
def get_aws_store(store_name, root="s3://ncar-na-cordex", data_frequency="day"):
    """Open a consolidated Zarr store on Amazon S3 and return an xarray Dataset.

    Parameters
    ----------
    store_name : str
        Name of the Zarr store, e.g. 'prec.rcp85.day.NAM-44i.raw.zarr'.
    root : str, optional
        S3 URL of the bucket that holds the stores.
    data_frequency : str, optional
        Sub-directory of `root` in which the store lives.

    Returns
    -------
    xarray.Dataset
        Dataset opened with ``xr.open_zarr`` (lazily loaded).
    """
    # Anonymous access: no AWS credentials are used.
    fs = s3fs.S3FileSystem(anon=True)
    full_path = f"{root}/{data_frequency}/{store_name}"
    # get_mapper is the fsspec equivalent of the deprecated s3fs.S3Map wrapper.
    store = fs.get_mapper(full_path)
    return xr.open_zarr(store, consolidated=True)
[ ]:
def get_glade_store(store_name, root="/glade/scratch/bonnland/na-cordex/zarr-publish"):
    """Open a consolidated Zarr store on NCAR Glade and return an xarray Dataset.

    Parameters
    ----------
    store_name : str
        Name of the Zarr store directory under `root`.
    root : str, optional
        Glade directory that holds the stores.

    Returns
    -------
    xarray.Dataset
        Dataset opened with ``xr.open_zarr`` (lazily loaded).
    """
    full_path = f"{root}/{store_name}"
    return xr.open_zarr(full_path, consolidated=True)
[ ]:
Define Plot Functions
Create a Single Map Plot (Helper Function)
[ ]:
def plotMap(ax, map_slice, date_object=None, member_id=None):
    '''Draw a map on the given axes and annotate it with its min/max values.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    map_slice : xarray.DataArray
        Field with 'lat' and 'lon' dimensions (reduced over both for min/max).
    date_object : xarray.DataArray, optional
        Scalar time coordinate; the first 10 characters of its string form
        are used as the axes title.
    member_id : str, optional
        Label for the y axis.

    Returns
    -------
    matplotlib.axes.Axes
        The same axes that were passed in.
    '''
    ax.imshow(map_slice, origin='lower')
    minval = float(map_slice.min(dim=['lat', 'lon']))
    maxval = float(map_slice.max(dim=['lat', 'lon']))
    # "%4g": minimum field width 4, default %g precision (6 significant digits).
    ax.text(0.01, 0.03, "%4g" % minval, transform=ax.transAxes, fontsize=12)
    ax.text(0.99, 0.03, "%4g" % maxval, transform=ax.transAxes, fontsize=12,
            horizontalalignment='right')
    ax.set_xticks([])
    ax.set_yticks([])
    # Compare with None explicitly: truth-testing a 0-d DataArray depends on
    # its value (e.g. a datetime64 at the epoch is falsy).
    if date_object is not None:
        # str() handles both numpy datetime64 scalars and 0-d cftime object arrays.
        ax.set_title(str(date_object.values)[:10], fontsize=12)
    if member_id is not None:
        ax.set_ylabel(member_id, fontsize=12)
    return ax
Function Producing Maps of the First, Middle, and Last Timesteps
[ ]:
def getValidDateIndexes(member_slice):
    '''Return the indexes of the first and last time steps with finite data.

    A time step counts as valid when the minimum of `member_slice` over
    'lat' and 'lon' is finite (NaN and +/-inf minima are treated as invalid).

    Parameters
    ----------
    member_slice : xarray.DataArray
        Data for one ensemble member with 'lat' and 'lon' dimensions plus one
        remaining (time) dimension -- presumably (time, lat, lon); confirm.

    Returns
    -------
    tuple
        (start_index, end_index) positions along the remaining dimension.

    Raises
    ------
    ValueError
        If no time step has a finite minimum.
    '''
    min_values = member_slice.min(dim=['lat', 'lon'])
    finite_indexes = np.flatnonzero(np.isfinite(min_values))
    if finite_indexes.size == 0:
        raise ValueError("member_slice has no time step with finite values")
    return finite_indexes[0], finite_indexes[-1]
def plot_first_mid_last(ds, data_var, store_name, plotdir, max_members=4):
    '''Plot maps of the first valid, middle, and last valid time steps.

    One row per ensemble member (at most `max_members`) and three columns:
    the first time step with finite data, the time step at the middle of the
    full time axis, and the last time step with finite data.

    With 30 workers, expect about 1 minute of computation and 1-2 minutes of
    plot rendering on Glade.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with 'member_id' and 'time' coordinates.
    data_var : str
        Name of the variable in `ds` to plot.
    store_name : str
        Used as the figure title.
    plotdir : str
        Not used by this function; kept for interface compatibility.
    max_members : int, optional
        Maximum number of ensemble members (rows) to plot.
    '''
    member_names = ds.coords['member_id'].values[0:max_members]
    num_rows = max(member_names.size, 1)
    num_cols = 3
    fig_width, fig_height = 18, 12
    # squeeze=False keeps axs 2-D so axs[row, col] also works for one row.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(fig_width, fig_height),
                            constrained_layout=True, squeeze=False)

    # The middle step is taken from the full time axis (not the valid range);
    # it is the same for every member, so compute it once.
    mid_date = ds.time[len(ds.time) // 2]

    for row, mem_id in enumerate(member_names):
        data_slice = ds[data_var].sel(member_id=mem_id)
        start_index, end_index = getValidDateIndexes(data_slice)
        start_date = ds.time[start_index]
        end_date = ds.time[end_index]
        plotMap(axs[row, 0], data_slice.sel(time=start_date), start_date, mem_id)
        plotMap(axs[row, 1], data_slice.sel(time=mid_date), mid_date)
        plotMap(axs[row, 2], data_slice.sel(time=end_date), end_date)

    fig.suptitle(store_name, fontsize=20)
    plt.show()
Create Time Series Plots over Multiple Pages
These plots also mark the time steps that contain missing values.
[ ]:
def plot_timeseries(ds, data_var, store_name, plotdir, max_members=4):
    '''Plot spatial min/max/mean/std time series for each ensemble member.

    One panel per member (at most `max_members`). Time steps whose spatial
    minimum is NaN are marked with a magenta rug near the bottom of the panel.

    With 30 workers, expect about 1 minute of computation and 1-2 minutes of
    plot rendering on Glade.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with 'member_id' and 'time' coordinates.
    data_var : str
        Name of the variable in `ds` to plot; its 'units' attribute (if any)
        labels the y axis.
    store_name : str
        Used as the figure title.
    plotdir : str
        Not used by this function; kept for interface compatibility.
    max_members : int, optional
        Maximum number of ensemble members (panels) to plot.
    '''
    member_names = ds.coords['member_id'].values[0:max_members]
    num_rows = max(member_names.size, 1)
    fig_width, fig_height = 25, 20
    linewidth = 0.5
    # Fall back to an empty label rather than failing on a missing attribute.
    unit_string = ds[data_var].attrs.get('units', '')
    time = ds.time
    stat_names = ('min', 'max', 'mean', 'std')

    # squeeze=False keeps axs 2-D so axs[row, 0] also works for one panel.
    fig, axs = plt.subplots(num_rows, 1, figsize=(fig_width, fig_height), squeeze=False)
    for row, mem_id in enumerate(member_names):
        ax = axs[row, 0]
        data_slice = ds[data_var].sel(member_id=mem_id)
        # Compute all four reductions in one go: with dask-backed data this
        # reads the chunks once instead of once per statistic.
        stats = xr.Dataset({
            'min': data_slice.min(dim=['lat', 'lon']),
            'max': data_slice.max(dim=['lat', 'lon']),
            'mean': data_slice.mean(dim=['lat', 'lon']),
            'std': data_slice.std(dim=['lat', 'lon']),
        }).compute()
        for name in stat_names:
            ax.plot(time, stats[name], linewidth=linewidth, label=name)

        # Rug of markers just above the bottom of the y range at NaN steps.
        nan_times = time[np.isnan(stats['min'].values)]
        ymin, ymax = ax.get_ylim()
        rug_y = ymin + 0.01 * (ymax - ymin)
        ax.plot(nan_times, [rug_y] * len(nan_times), '|', color='m', label='isnan')

        ax.set_title(mem_id, fontsize=20)
        ax.legend(loc='upper right')
        ax.set_ylabel(unit_string)

    fig.suptitle(store_name, fontsize=25)
    fig.tight_layout(pad=10.2, w_pad=3.5, h_pad=3.5)
    plt.show()
Load Zarr Stores from Different Sources and Record Plotting Times
[ ]:
# Store opened from both Glade and AWS, the variable to plot, and the
# directory argument handed to the plot functions.
store_name, plotdir, data_var = (
    'prec.rcp85.day.NAM-44i.raw.zarr',
    '.',
    'prec',
)
[ ]:
Plot First, Middle, and Final Timesteps from the Glade Store
[ ]:
%%time
# Glade benchmark (maps): %%time reports the time for the whole cell,
# i.e. opening the Zarr store on Glade plus computing and rendering the plot.
ds = get_glade_store(store_name)
plot_first_mid_last(ds, data_var, store_name, plotdir)
Plot First, Middle, and Final Timesteps from the AWS Store
[ ]:
%%time
# AWS benchmark (maps): same work as the Glade maps cell, but the Zarr store
# is read from the S3 bucket; %%time reports the time for the whole cell.
ds = get_aws_store(store_name)
plot_first_mid_last(ds, data_var, store_name, plotdir)
[ ]:
[ ]:
Plot Time Series from the Glade Store
[ ]:
%%time
# Glade benchmark (time series): %%time reports the time for opening the
# Zarr store on Glade plus computing and rendering the time-series plot.
ds = get_glade_store(store_name)
plot_timeseries(ds, data_var, store_name, plotdir)
Plot Time Series from the AWS Store
[ ]:
%%time
# AWS benchmark (time series): same work as the Glade time-series cell, but
# the Zarr store is read from the S3 bucket.
ds = get_aws_store(store_name)
plot_timeseries(ds, data_var, store_name, plotdir)
[ ]:
Release the Workers
[ ]:
# Shell escape: print the current date and time to record when the run ended.
!date
[ ]:
# Close the Dask client connection first, then shut down the cluster and its
# batch-job workers so no scheduler resources are left allocated.
client.close()
cluster.close()
[ ]: