Skip to content

Commit

Permalink
Merge pull request #29 from ACCESS-NRI/flick/edit_enso_diagnostic_script
Browse files Browse the repository at this point in the history
Restructure enso diagnostic script
  • Loading branch information
flicj191 authored Nov 28, 2024
2 parents 6af6953 + 2fa035a commit 87b410f
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 131 deletions.
5 changes: 5 additions & 0 deletions recipe_diagnostics/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Recipes

- climatology_metrics.yml
- climatology_diaglevel2.yml
- enso_metrics.yml

Diagnostics are stored in *diagnostic_scripts/*

Expand Down Expand Up @@ -53,6 +54,8 @@ Script: **matrix.py**


* HadISST
* ERA-Interim
* GPCP-SG


### References
Expand All @@ -63,3 +66,5 @@ Script: **matrix.py**

<p align="center"><img src="figures/plot_matrix.png" alt="portrait plot" width="60%"/></p>

<p align="center"><img src="figures/plot_matrix_enso.png" alt="portrait plot" width="60%"/></p>

203 changes: 87 additions & 116 deletions recipe_diagnostics/diagnostic_scripts/enso_diag1metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,23 @@ def plot_level1(input_data, metricval, y_label, title, dtls): #input data is 2 -
# model first
plt.plot(*input_data[0], label=dtls[0])
plt.plot(*input_data[1], label=f'ref: {dtls[1]}', color='black')
val_type = 'RMSE'
plt.text(0.5, 0.95, f"RMSE: {metricval:.2f}", fontsize=12, ha='center', transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

else:
plt.scatter(range(len(input_data)), input_data, c=['black','blue'], marker='D')
# obs first
plt.xlim(-0.5,2)#range(-1,3,1)) #['model','obs']
plt.xticks([])
plt.text(0.75,0.85, f'* {dtls[0]}', color='blue',transform=plt.gca().transAxes)
plt.text(0.75,0.8, f'* ref: {dtls[1]}', color='black',transform=plt.gca().transAxes)
val_type = 'metric(%)'
plt.text(0.75,0.95, f'* {dtls[0]}', color='blue', transform=plt.gca().transAxes)
plt.text(0.75,0.9, f'* ref: {dtls[1]}', color='black', transform=plt.gca().transAxes)
plt.text(0.75, 0.8, f"metric(%): {metricval:.2f}", fontsize=12, transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

plt.title(title) # metric name
plt.legend()
plt.grid(linestyle='--')
plt.ylabel(y_label) #param
# metric type: RMSE or %
plt.text(0.5, 0.95, f"{val_type}: {metricval:.2f}", fontsize=12, ha='center', transform=plt.gca().transAxes,
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

if title == 'ENSO pattern': # if array, not scatter
plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(format_longitude))
Expand All @@ -67,31 +67,15 @@ def plot_level1(input_data, metricval, y_label, title, dtls): #input data is 2 -

return figure

def lin_regress(cube_ssta, cube_nino34): #1d
def lin_regress(cube_ssta, cube_nino34): #1d pattern
slope_ls = []
for lon_slice in cube_ssta.slices(['time']):
res = linregress(cube_nino34.data, lon_slice.data)
slope_ls.append(res[0])

return cube_ssta.coord('longitude').points, slope_ls

def pattern_09(input_pair, dt_ls): #['tos_patdiv1':, 'tos_pat2':] input_pair

# obs first
mod_ssta = input_pair[1]['tos_pat2']
mod_nino34 = input_pair[1]['tos_patdiv1']
reg_mod = lin_regress(mod_ssta, mod_nino34)
reg_obs = lin_regress(input_pair[0]['tos_pat2'], input_pair[0]['tos_patdiv1'])

rmse = np.sqrt(np.mean((np.array(reg_obs[1]) - np.array(reg_mod[1])) ** 2))
#save data? reg_mod as cube?

# plot functions? title
fig = plot_level1([reg_mod,reg_obs], rmse, 'reg(ENSO SSTA, SSTA)', 'ENSO pattern', dt_ls)

return rmse, fig

def sst_regressed(n34_cube): #dict
def sst_regressed(n34_cube): #for lifecycle
# params cubes,
n34_dec = extract_month(n34_cube, 12)
n34_dec = xr.DataArray.from_iris(n34_dec)
Expand All @@ -115,69 +99,83 @@ def sst_regressed(n34_cube): #dict
slope = scp.LinReg(n34_dec_ct.values, n34_selected).slope
return slope

def lifecycle_10(input_pair, dtls): #variable_group
# inputs pairs of model and obs
## metric computation - rmse of slopes
logger.info(input_pair[1]['tos_lifdur1'])
model = sst_regressed(input_pair[1]['tos_lifdur1']) #n34_cube
obs = sst_regressed(input_pair[0]['tos_lifdur1'])
rmse = np.sqrt(np.mean((obs - model) ** 2))
months = np.arange(1, 73) - 36 #build tuples?
#save data? slope as cube?

# plot function #need xticks, labels as dict/ls
fig = plot_level1([ (months,model),(months,obs)], rmse, 'Degree C / C','ENSO lifecycle', dtls)
return rmse, fig

def amplitude_11(input_pair, dtls, var_group):
metric = [input_pair[1][var_group].data.item(),input_pair[0][var_group].data.item()]
val = compute(metric[1],metric[0])
#plt.scatter(range(len(metric)), metric, c=['blue','black'], marker='D')
fig = plot_level1(metric, val, 'SSTA std (°C)','ENSO amplitude', dtls)
return val, fig
def seasonality_12(input_pair, dtls, var_group):
# cubes, season, climate
metric = []
for ds in input_pair: #obs 0, mod 1
preproc = {}
for seas in ['NDJ','MAM']:
cube = extract_season(ds[var_group], seas)
cube = climate_statistics(cube, operator="std_dev", period="full")
preproc[seas] = cube.data

ds_val = preproc['NDJ']/preproc['MAM']
metric.append(ds_val)

val = compute(metric[1],metric[0])
#plt.scatter(range(len(metric)), metric, c=['blue','black'], marker='D')
fig = plot_level1(metric, val, 'SSTA std (NDJ/MAM)(°C/°C)','ENSO seasonality', dtls)
return val, fig
def asymmetry_13(input_pair, dtls, var_group):
model_skew = skew(input_pair[1][var_group].data, axis=0)
obs_skew = skew(input_pair[0][var_group].data, axis=0)
metric = [model_skew,obs_skew]

val = compute(metric[1],metric[0])
#plt.scatter(range(len(metric)), metric, c=['blue','black'], marker='D')
fig = plot_level1(metric, val, 'SSTA skewness(°C)','ENSO skewness', dtls)
return val, fig
def duration_14(input_pair, dtls, var_group):
# inputs pairs of model and obs

model = sst_regressed(input_pair[1][var_group])
obs = sst_regressed(input_pair[0][var_group])

months = np.arange(1, 73) - 36
counts = []
# Calculate the number of months where slope > 0.25 in the range -20 to 20
within_range = (months >= -30) & (months <= 30)
for slopes in [model, obs]:
slope_above_025 = slopes[within_range] > 0.25
counts.append(np.sum(slope_above_025))
val = compute(counts[1],counts[0])

fig = plot_level1(counts, val, 'Duration (reg > 0.25) (months)','ENSO duration', dtls)
return val, fig
def compute_enso_metrics(input_pair, dt_ls, var_group, metric): #['tos_patdiv1':, 'tos_pat2':]

# input_pair: obs first
if metric == '09pattern':
model_ssta = input_pair[1][var_group[1]]
model_nino34 = input_pair[1][var_group[0]]
reg_mod = lin_regress(model_ssta, model_nino34)
reg_obs = lin_regress(input_pair[0][var_group[1]], input_pair[0][var_group[0]])

val = np.sqrt(np.mean((np.array(reg_obs[1]) - np.array(reg_mod[1])) ** 2))
#save data? reg_mod as cube?
# plot functions? ylabel, title, data labels
fig = plot_level1([reg_mod,reg_obs], val, 'reg(ENSO SSTA, SSTA)', 'ENSO pattern', dt_ls)

elif metric =='10lifecycle':
model = sst_regressed(input_pair[1][var_group[0]]) #n34_cube
obs = sst_regressed(input_pair[0][var_group[0]])
val = np.sqrt(np.mean((obs - model) ** 2))
months = np.arange(1, 73) - 36 #build tuples?
# plot function #need xticks, labels as dict/ls
fig = plot_level1([ (months,model),(months,obs)], val, 'Degree C / C', 'ENSO lifecycle', dt_ls)

elif metric =='11amplitude':
data_values = [input_pair[1][var_group[0]].data.item(),input_pair[0][var_group[0]].data.item()]
val = compute(data_values[1], data_values[0])
#plt.scatter(range(len(metric)), metric, c=['blue','black'], marker='D')
fig = plot_level1(data_values, val, 'SSTA std (°C)', 'ENSO amplitude', dt_ls)

elif metric =='12seasonality':
data_values = []
for ds in input_pair: #obs 0, mod 1
preproc = {}
for seas in ['NDJ','MAM']:
cube = extract_season(ds[var_group[0]], seas)
cube = climate_statistics(cube, operator="std_dev", period="full")
preproc[seas] = cube.data

data_values.append(preproc['NDJ']/preproc['MAM'])

val = compute(data_values[1], data_values[0])
fig = plot_level1(data_values, val, 'SSTA std (NDJ/MAM)(°C/°C)','ENSO seasonality', dt_ls)

elif metric =='13asymmetry':
model_skew = skew(input_pair[1][var_group[0]].data, axis=0)
obs_skew = skew(input_pair[0][var_group[0]].data, axis=0)
data_values = [model_skew, obs_skew]

val = compute(data_values[1], data_values[0])
fig = plot_level1(data_values, val, 'SSTA skewness(°C)','ENSO skewness', dt_ls)

elif metric =='14duration':
model = sst_regressed(input_pair[1][var_group[0]])
obs = sst_regressed(input_pair[0][var_group[0]])

months = np.arange(1, 73) - 36
counts = []
# Calculate the number of months where slope > 0.25 in the range -20 to 20
within_range = (months >= -30) & (months <= 30)
for slopes in [model, obs]:
slope_above_025 = slopes[within_range] > 0.25
counts.append(np.sum(slope_above_025))
val = compute(counts[1], counts[0])

fig = plot_level1(counts, val, 'Duration (reg > 0.25) (months)','ENSO duration', dt_ls)
elif metric =='15diversity':
data_values = []
for ds in input_pair: #obs first
events = enso_events(ds[var_group[0]])
results_lon = diversity(ds[var_group[1]], events)
results_lon['enso'] = results_lon['nino'] + results_lon['nina']
logger.info(f"{dt_ls}, enso IQR: {iqr(results_lon['enso'])}")
data_values.append(iqr(results_lon['enso']))

val = compute(data_values[1], data_values[0])
fig = plot_level1(data_values, val, 'IQR of min/max SSTA(°long)','ENSO diversity', dt_ls)

return val, fig

def mask_to_years(events): # build time with mask
maskedTime = np.ma.masked_array(events.coord('time').points, mask=events.data.mask)
Expand All @@ -204,19 +202,6 @@ def diversity(ssta_cube, events_dict): #2 masks/events list

res_lon[enso] = loc_ls
return res_lon # return data to plot
def diversity_15(input_pair, dtls, var_groupls):
metric = []
for ds in input_pair: #obs first
events = enso_events(ds[var_groupls[0]])

results_lon = diversity(ds[var_groupls[1]], events)
results_lon['enso'] = results_lon['nino'] + results_lon['nina']
logger.info(f"{dtls}, enso IQR: {iqr(results_lon['enso'])}")
metric.append(iqr(results_lon['enso']))

val = compute(metric[1],metric[0])
fig = plot_level1(metric, val, 'IQR of min/max SSTA(°long)','ENSO diversity', dtls)
return val, fig

def iqr(data):
q3, q1 = np.percentile(data, [75 ,25])
Expand Down Expand Up @@ -299,21 +284,7 @@ def main(cfg):
logger.info(pformat(model_datasets))
# process function for each metric - obs first.. if, else
### make one function, with the switches - same params
if metric == '09pattern':
# sort datasetfiles
value, fig = pattern_09(input_pair, [dataset, obs[0]['dataset']])
elif metric =='10lifecycle':
value, fig = lifecycle_10(input_pair, [dataset, obs[0]['dataset']])
elif metric =='11amplitude':
value,fig = amplitude_11(input_pair, [dataset, obs[0]['dataset']], var_preproc[0])
elif metric =='12seasonality':
value,fig = seasonality_12(input_pair, [dataset, obs[0]['dataset']], var_preproc[0])
elif metric =='13asymmetry':
value,fig = asymmetry_13(input_pair, [dataset, obs[0]['dataset']], var_preproc[0])
elif metric =='14duration':
value,fig = duration_14(input_pair, [dataset, obs[0]['dataset']], var_preproc[0])
elif metric =='15diversity':
value,fig = diversity_15(input_pair, [dataset, obs[0]['dataset']], var_preproc)
value, fig = compute_enso_metrics(input_pair, [dataset, obs[0]['dataset']], var_preproc, metric)

# save metric for each pair, check not none
if value:
Expand Down
10 changes: 7 additions & 3 deletions recipe_diagnostics/diagnostic_scripts/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging

import pandas as pd
import numpy as np
from esmvaltool.diag_scripts.shared import (run_diagnostic,
save_figure)

Expand All @@ -17,7 +18,7 @@
def plot_matrix(diag_path):

metric_df = pd.read_csv(diag_path, header=None)
# TO DO: run normalisation on all these values
# run normalisation on all these values
metric_df[2] = (metric_df[2]-metric_df[2].mean())/metric_df[2].std()

transformls = []
Expand All @@ -31,11 +32,14 @@ def plot_matrix(diag_path):
plt.colorbar()
plt.xticks(range(len(matrixdf.columns)), matrixdf.columns, rotation=45, ha='right')
plt.yticks(range(len(matrixdf.index)), matrixdf.index, wrap=True)
plt.xticks(np.arange(matrixdf.shape[1] + 1) - 0.5, minor=True)
plt.yticks(np.arange(matrixdf.shape[0] + 1) - 0.5, minor=True)
plt.tick_params(which="both", bottom=False, left=False)
plt.grid(which="minor", color="black", linestyle="-", linewidth=0.5)

return figure



def main(cfg):
"""Read metrics and plot matrix."""
provenance_record = {
Expand All @@ -54,7 +58,7 @@ def main(cfg):

figure = plot_matrix(diag_path)

save_figure('plot_matrix', provenance_record, cfg, figure=figure)
save_figure('plot_matrix', provenance_record, cfg, figure=figure, bbox_inches='tight')

if __name__ == '__main__':

Expand Down
Binary file added recipe_diagnostics/figures/plot_matrix_enso.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 11 additions & 12 deletions recipe_diagnostics/recipe_enso_metrics.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ documentation:
datasets:
- {dataset: ACCESS-ESM1-5, project: CMIP6, mip: Omon, exp: historical, ensemble: r1i1p1f1, grid: gn, start_year: 1950, end_year: 2014}
- {dataset: ACCESS-CM2, project: CMIP6, mip: Omon, exp: historical, ensemble: r1i1p1f1, grid: gn, start_year: 1950, end_year: 2014}
# - {dataset: BCC-CSM2-MR, project: CMIP6, exp: historical, ensemble: r1i1p1f1, grid: gn, start_year: 1950, end_year: 2014}
# - {dataset: BCC-ESM1, project: CMIP6, exp: historical, ensemble: r1i1p1f1, grid: gn, start_year: 1950, end_year: 2014}
- {dataset: BCC-CSM2-MR, project: CMIP6, mip: Omon, exp: historical, ensemble: r1i1p1f1, grid: gn, start_year: 1950, end_year: 2014}
- {dataset: CAMS-CSM1-0, project: CMIP6, mip: Omon, exp: historical, ensemble: r1i1p1f1, grid: gn, start_year: 1950, end_year: 2014}

- {dataset: HadISST, project: OBS, type: reanaly, tier: 2, mip: Omon}

Expand Down Expand Up @@ -118,14 +118,13 @@ diagnostics:
plot_script:
script: /home/189/fc6164/esmValTool/repos/ENSO_recipes/recipe_diagnostics/diagnostic_scripts/enso_diag1metrics.py

# diag_collect:
# description: collect metrics
# variables:
# pr: #dummy variable to fill recipe req
# mip: Amon
# scripts:
# matrix_collect:
# script: /home/189/fc6164/esmValTool/repos/ENSO_recipes/recipe_diagnostics/matrix.py
# # above diagnostic name and script name
# diag_metrics: diagnostic_metrics/plot_script #cfg['work_dir']
diag_collect:
description: collect metrics
variables:
tos: #dummy variable to fill recipe requirements
scripts:
matrix_collect:
script: /home/189/fc6164/esmValTool/repos/ENSO_recipes/recipe_diagnostics/diagnostic_scripts/matrix.py
# above diagnostic name and script name
diag_metrics: diagnostic_metrics/plot_script #cfg['work_dir']

0 comments on commit 87b410f

Please sign in to comment.