import shutil
import re
import itertools
from functools import wraps
from pathlib import Path
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from cycler import cycler
from astropy import units as u
from . import profiles
from . import utils
from . import cores
from . import constants
from . import parsers
from . import shapes
# Route warnings through a compact one-line format: "Warning: <message>"
def _format_warning(message, *_ignored):
    return f"Warning: {message}\n"


warnings.formatwarning = _format_warning

# Install a module-wide matplotlib property cycle that pairs each linestyle with every colour of
# the current colour cycle (the colour varies fastest)
linestyles = ['-', '--', ':', '-.']
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
style_cycler = cycler(linestyle=linestyles) * cycler(color=colors)
plt.rcParams['axes.prop_cycle'] = style_cycler
def plotting(func):
    """Decorator for methods that draw onto a single matplotlib Axes.

    The decorated method is called as ``func(self, ax, *args, **kwargs)``. If the caller omits
    ``ax`` a new Figure and Axes are created; otherwise the Axes' own Figure is used. The wrapper
    returns ``(fig, ax)`` and discards whatever ``func`` itself returns.
    """
    @wraps(func)
    def _plot_wrapper(self, ax=None, *args, **kwargs):
        fig, ax = plt.subplots(1, 1) if ax is None else (ax.get_figure(), ax)
        func(self, ax, *args, **kwargs)
        return fig, ax

    return _plot_wrapper
def plotting_altitude(func):
    """Decorator for methods that draw a vertical profile against pressure or height.

    The decorated method is called as ``func(self, ax, pressure, *args, **kwargs)``. If the caller
    omits ``ax`` a new Figure and Axes are created; otherwise the Axes' own Figure is used. When
    ``pressure`` is true the y-axis is made logarithmic, labelled in bar and limited to
    ``self.core.min_pressure``..``self.core.max_pressure`` with the axis inverted; otherwise it is
    labelled as height in km. The wrapper returns ``(fig, ax)``.
    """
    @wraps(func)
    def _profile_wrapper(self, ax=None, pressure=True, *args, **kwargs):
        fig, ax = plt.subplots(1, 1) if ax is None else (ax.get_figure(), ax)
        if not pressure:
            ax.set_ylabel("Height (km)")
        else:
            ax.set_yscale("log")
            ax.set_ylabel("Pressure (bar)")
            ax.set_ylim(self.core.min_pressure, self.core.max_pressure)
            # Flip so that pressure increases downwards (assuming min_pressure < max_pressure)
            ax.invert_yaxis()
        func(self, ax, pressure, *args, **kwargs)
        return fig, ax

    return _profile_wrapper
class NemesisResult:
    """Class for storing and using the results of a NEMESIS retrieval.

    Attributes:
        core_directory (pathlib.Path): The directory of the core being analysed
        core (NemesisCore): The NemesisCore object that generated the core directory
        profiles (dict[Profile]): A dictionary of all the retrieved Profile objects from the run. The keys are the labels given to the Profiles on creation (eg. GasProfiles have the form "<gas_name> <isotope_id>" such as "PH3 0")
        latitude (float): Latitude of the observed spectrum
        longitude (float): Longitude of the observed spectrum
        chi_sq (float): The chi-squared value of the retrieval (the last entry of chi_sqs)
        chi_sqs (sequence[float]): The chi-squared values parsed from nemesis.prc, one per iteration
        elapsed_time (float): The time taken for the retrieval in decimal hours
        retrieved_spectrum (pandas.DataFrame): A DataFrame containing the measured and modelled spectrum
        retrieved_aerosols (pandas.DataFrame): A DataFrame containing the retrieved aerosol profiles
        retrieved_gases (pandas.DataFrame): A DataFrame containing the retrieved chemical profiles
    """

    def __init__(self, core_directory):
        """Load a finished NEMESIS core and parse its output files.

        Args:
            core_directory (str): The directory of the core
        """
        # Load core directory
        self.core_directory = Path(core_directory)
        self.core = cores.load_core(self.core_directory)
        self.core.spectrum = self.core_directory / "nemesis.spx"
        self.profiles = self.core.profiles

        # Parse some files
        self.mre = parsers.NemesisMre(self.core_directory / "nemesis.mre")
        self.aerosol_prf = parsers.AerosolRef(self.core_directory / "aerosol.prf")
        self.nemesis_prf = parsers.NemesisRef(self.core_directory / "nemesis.prf")
        # Give the aerosol table the pressure column of nemesis.prf (presumably the same levels)
        self.aerosol_prf.data["pressure"] = self.nemesis_prf.data["pressure"]

        # Parse the iterations file for retrievals
        if not self.core.forward:
            self.itr = parsers.NemesisItr(self.core_directory / "nemesis.itr")
            self.itr.add_column_names(self.profiles)

        # Add the mre object attributes to the NemesisResult object for convenience
        self.__dict__ |= self.mre.__dict__
        self.retrieved_aerosols = self.aerosol_prf.data
        self.retrieved_gases = self.nemesis_prf.data
        self.chi_sqs = self._get_chi_squareds()
        self.chi_sq = self.chi_sqs[-1]

        # Add results to the profiles
        self._add_results_to_profiles()

        # Cache for the elapsed_time property (filled in on first access)
        self._time = None

    def _add_results_to_profiles(self):
        """Hand the retrieved parameters from nemesis.mre to each Profile, in dictionary order.

        Profiles that are AerosolProfiles with ``retrieve_optical`` set receive two consecutive
        entries; every other profile receives one. The parsed ``self.mre.retrieved_parameters``
        list is copied first so that it is not emptied as a side effect (it is also shared with
        this object's own attributes via ``__dict__``).

        Raises:
            IndexError: If nemesis.mre holds fewer retrieved parameters than the profiles need
        """
        remaining = list(self.mre.retrieved_parameters)
        for profile in self.profiles.values():
            if isinstance(profile.shape, shapes.Shape0):
                # NOTE(review): no parameter is consumed for Shape0 profiles, so later profiles
                # may be paired with the wrong entry - confirm against the .mre layout
                warnings.warn("Shape0 is not fully supported (from NemesisResult._add_results_to_profiles)")
                continue
            if isinstance(profile, profiles.AerosolProfile) and profile.retrieve_optical:
                profile._add_result(remaining.pop(0), remaining.pop(0))
            else:
                profile._add_result(remaining.pop(0))

    def _get_chi_squareds(self):
        """Parse the chi-squared values from nemesis.prc and write them out to nemesis.chi.

        The .prc file is re-parsed, and nemesis.chi rewritten, on every call.

        Returns:
            The chi-squared values parsed by parsers.NemesisPrc (its ``chisq`` attribute)

        Creates:
            nemesis.chi
        """
        prc = parsers.NemesisPrc(self.core_directory / "nemesis.prc")
        prc.write_chisqs(self.core_directory / "nemesis.chi")
        return prc.chisq

    def _convert_aerosol_units(self, data, p_ref=1):
        """Convert the aerosol densities in the aerosol.prf file to more useful units.

        This assumes that the cross-sections are normalised at a specific wavelength, such that the
        native units are cm2/g. The atmospheric density and scale height used for the conversion
        are evaluated once, at the reference level whose pressure is nearest to ``p_ref``, and
        applied to every value in ``data``.

        Args:
            data (array-like): Aerosol densities in cm2/g
            p_ref (float): The reference pressure in bar at which to calculate everything. Default is 1 bar and should not need to be changed.

        Returns:
            opacity_ODkm (np.ndarray): Aerosol densities in optical depth per km
            opacity_ODbar (np.ndarray): Aerosol densities in optical depth per bar
            opacity_ODbar_int (np.ndarray): Running (cumulative) sum of opacity_ODbar along ``data``
        """
        reference = self.core.reference.data
        non_gas_columns = ["height", "pressure", "temperature"]

        # First, calculate the mean molecular weight of the atmosphere at every level
        masses = []
        for gas in reference.columns:
            if gas in non_gas_columns:
                continue
            g = gas.split(" ")[0]  # ignore isotope weights
            m = constants.GASES.loc[constants.GASES["name"] == g, "molar_mass"].iloc[0]
            masses.append(m)
        vmrs = reference.copy()
        vmrs["mean_molar_wt"] = (vmrs.drop(columns=non_gas_columns) * masses).sum(axis=1)

        # Then calculate the values at the target level
        target = vmrs.loc[utils.find_nearest(reference["pressure"], p_ref)[0]]
        pres = target["pressure"] * u.bar
        temp = target["temperature"] * u.K
        mol_wt = target["mean_molar_wt"] * u.g / u.mol
        gravity = constants.GRAVITY[self.core.planet] * u.m / u.s**2
        R = 8.314 * u.J / u.K / u.mol
        Hscale = (R * temp) / (mol_wt * gravity)  # atmospheric scale height
        rho = (mol_wt * pres) / (R * temp)  # atmospheric mass density (ideal gas law)

        opacity_ODkm = (list(data) * u.cm**2 / u.g * rho).to(1 / u.km)
        opacity_ODbar = (opacity_ODkm * Hscale / pres).to(1 / u.bar)
        # NOTE(review): this is a running sum of tau/bar values, not an integral over pressure
        # (there is no factor of the layer thickness dp) - confirm this is the intended quantity
        opacity_ODbar_int = np.cumsum(opacity_ODbar)
        return opacity_ODkm.value, opacity_ODbar.value, opacity_ODbar_int.value

    def print_summary(self, colors=False):
        """Print timing, iteration and chi-squared information followed by each profile's table.

        Args:
            colors (bool): Passed through to each Profile's print_table method
        """
        n_iterations = len(self.chi_sqs)
        print(f"Summary of retrieval in {self.core_directory}")
        print()
        print(f"Time taken: {utils.format_decimal_hours(self.elapsed_time)}")
        print(f"Number of iterations: {n_iterations}")
        print("Time per iteration: " + utils.format_decimal_hours(self.elapsed_time / n_iterations))
        print(f"Number of retrieved parameters: {len(self.mre.initial_state_vector)}")
        print(f"Chi squared value: {self.chi_sq}")
        print()
        for profile in self.profiles.values():
            profile.print_table(colors=colors, forward=self.core.forward)

    @property
    def elapsed_time(self):
        """Return the time taken for the retrieval in hours.

        The first number on the last line of nemesis.prc is read as a time in seconds and
        converted to hours. The value is cached after the first access.
        """
        if getattr(self, "_time", None) is None:
            with open(self.core_directory / "nemesis.prc") as file:
                last_line = file.read().splitlines()[-1]
            seconds = float(re.findall(r"[-+]?\d*\.\d+|\d+", last_line)[0])
            self._time = seconds / 3600
        return self._time

    @plotting
    def plot_chisq(self, ax):
        """Plot the chi-squared values as a function of iteration number

        Args:
            ax: The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes: The Axes object onto which the data was plotted
        """
        if self.core.forward:
            ax.scatter(0, self.chi_sq)
        else:
            ax.plot(self.chi_sqs)
        ax.axhline(y=1, ls="dashed")
        ax.set_xlabel("Iteration Number")
        ax.set_ylabel(r"$\chi^2$")

    @plotting
    def plot_spectrum(self, ax, show_chisq=True, legend=True, log=False):
        """Plot the measured and model spectrum on a matplotlib Axes.

        Args:
            ax: The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes
            show_chisq (bool): Whether to display the chi-squared value of the fit
            legend (bool): Whether to draw the legend
            log (bool): Whether to use a log plot

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes: The Axes object onto which the data was plotted"""
        spectrum = self.retrieved_spectrum
        if log:
            ax.set_yscale("log")
        ax.plot(spectrum.wavelength, spectrum.measured, lw=0.5, label="Measured" if legend else None)
        ax.fill_between(spectrum.wavelength,
                        spectrum.measured - spectrum.error,
                        spectrum.measured + spectrum.error,
                        alpha=0.5)
        ax.plot(spectrum.wavelength, spectrum.model, c="r", lw=0.5, label="Model" if legend else None)
        if self.chi_sq is not None and show_chisq:
            # Draw on the target Axes explicitly: plt.text would use whichever Axes is current
            ax.text(0.95, 0.05, rf"$\chi^2 = ${self.chi_sq:.3f}",
                    horizontalalignment='right',
                    verticalalignment='bottom',
                    transform=ax.transAxes)
        ax.set_xlabel("Wavelength (μm)")
        ax.set_ylabel("Radiance\n(μW cm$^{-2}$ sr$^{-1}$ μm$^{-1}$)")
        if legend:
            ax.legend()

    @plotting
    def plot_spectrum_residuals(self, ax, log=False):
        """Plot the spectrum residuals (model minus measured) on a matplotlib Axes.

        Args:
            ax: The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes
            log (bool): If True, plot ln(model) - ln(measured) with a band of +/- error/measured
                around zero; otherwise plot model - measured with a band of +/- error around the
                residuals

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes: The Axes object onto which the data was plotted"""
        spectrum = self.retrieved_spectrum
        if log:
            residuals = np.log(spectrum.model) - np.log(spectrum.measured)
            ax.plot(spectrum.wavelength, residuals, label="Residuals", lw=1)
            ax.fill_between(spectrum.wavelength,
                            -spectrum.error / spectrum.measured,
                            spectrum.error / spectrum.measured,
                            alpha=0.5, label="Error")
            ax.set_ylabel("Log Residuals")
        else:
            residuals = spectrum.model - spectrum.measured
            ax.plot(spectrum.wavelength, residuals, label="Residuals", lw=1)
            ax.fill_between(spectrum.wavelength,
                            residuals - spectrum.error,
                            residuals + spectrum.error,
                            alpha=0.5, label="Error")
            ax.set_ylabel("Residuals\n(μW cm$^{-2}$ sr$^{-1}$ μm$^{-1}$)")
        ax.axhline(y=0, zorder=-1, c="k", ls="dashed")
        ax.set_xlabel("Wavelength (μm)")
        ax.legend()

    @plotting_altitude
    def plot_temperature(self, ax, pressure):
        """Plot the prior and retrieved temperature profile on a matplotlib Axes.

        If no TemperatureProfile was retrieved, the temperature column of the reference (.ref)
        data is plotted instead.

        Args:
            ax: The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes
            pressure: Whether to plot the temperature profile against pressure (if True) or height (if False)

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes: The Axes object onto which the data was plotted"""
        # Get the appropriate y axis data
        reference = self.core.reference.data
        y = reference["pressure"] if pressure else reference["height"]

        # Find the retrieved temperature profile; self.profiles maps labels to Profile objects,
        # so the search must look at the values, not the keys
        temp_profile = next(
            (p for p in self.profiles.values() if isinstance(p, profiles.TemperatureProfile)),
            None,
        )
        if temp_profile is None:
            # Temperature not retrieved: plot the temperature profile in the .ref file
            ax.plot(reference["temperature"], y)
        else:
            ax.plot(temp_profile.shape.data.retrieved, y, c="k", lw=0.5, label="Retrieved")
            ax.plot(temp_profile.shape.data.prior, y, c="r", lw=0.5, label="Prior")
            ax.legend()
        ax.set_xlabel("Temperature (K)")

    @plotting_altitude
    def plot_aerosol_profiles(self, ax, pressure, unit="tau/bar"):
        """Plot the retrieved aerosol profiles against either height or pressure.

        Args:
            ax: The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes
            pressure: Whether to plot the aerosol profiles against pressure (if True) or height (if False)
            unit: The unit to convert the aerosol profiles to. Valid values are:
                'tau/bar' for optical depth per bar,
                'tau/km' for optical depth per km,
                'tau' for the cumulative optical depth computed by _convert_aerosol_units,
                'cm2/g' for the native prf units

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes: The Axes object onto which the data was plotted

        Raises:
            ValueError: If ``unit`` is not one of the values above (checked only when the core
                has aerosol modes)"""
        # Nothing to plot for a core without aerosols
        if self.core.num_aerosol_modes == 0:
            return

        unit_labels = {
            "tau/bar": "Optical depth / bar",
            "tau/km": "Optical depth / km",
            "tau": "Pressure-integrated optical depth",
            "cm2/g": "Aerosol density (cm$^2$ / gram)",
        }
        if unit not in unit_labels:
            raise ValueError(f"Invalid unit! - Must be one of 'tau/bar' 'tau/km' 'tau' or 'cm2/g' - not {unit}")
        unit_label = unit_labels[unit] + f" at {self.core.reference_wavelength:.1f}µm"

        aerosols = self.retrieved_aerosols
        y = aerosols.pressure if pressure else aerosols.height
        # Iterate over every retrieved aerosol
        for label in aerosols.columns:
            if label in ("height", "pressure"):
                continue
            x = aerosols[label]
            if unit != "cm2/g":
                tau_per_km, tau_per_bar, tau_integrated_bar = self._convert_aerosol_units(x)
                x = {"tau/bar": tau_per_bar, "tau/km": tau_per_km, "tau": tau_integrated_bar}[unit]
            ax.plot(x, y, label=label)
        ax.set_xlabel(unit_label)
        ax.set_xscale("log")
        ax.set_xlim(1e-4, 1e2)
        ax.legend()

    @plotting_altitude
    def plot_gas_profiles(self, ax,
                          pressure,
                          unit="",
                          gas_names=None,
                          plot_retrieved_profiles=True,
                          plot_prior_profiles=False,
                          plot_ref_profiles=True):
        """Plot gas profiles from the .prf file.

        Args:
            ax (matplotlib.Axes): The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes.
            pressure (bool): Whether to plot the gas profiles against pressure (if True) or height (if False).
            unit (str): One of '', 'ppm', 'ppb' or 'ppt'.
            gas_names (list[str] or str, optional): Gas name(s) to plot. If None, plot all gases.
            plot_retrieved_profiles (bool): Whether to plot the retrieved gas profiles
            plot_prior_profiles (bool): Whether to plot the prior gas profiles
            plot_ref_profiles (bool): Whether to plot the profile in the .ref file

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes: The Axes object onto which the data was plotted

        Raises:
            ValueError: If ``unit`` is not one of '', 'ppm', 'ppb', 'ppt'"""
        # Determine the scaling factor (checked first so a bad unit fails before any other work)
        scales = {"": 1, "ppm": 1e6, "ppb": 1e9, "ppt": 1e12}
        if unit not in scales:
            raise ValueError("Invalid unit! - Must be one of '', 'ppm', 'ppb', 'ppt'")
        scale = scales[unit]

        # If no gas names specified then get every gas profile
        if gas_names is None:
            gas_names = [x for x in self.retrieved_gases.columns if x not in ("height", "pressure", "temperature")]
        # Allow passing in of a single string instead of a list
        elif isinstance(gas_names, str):
            gas_names = [gas_names]

        # Get the appropriate y axes
        y = self.retrieved_gases["pressure"] if pressure else self.retrieved_gases["height"]
        reference = self.core.reference.data
        y2 = reference["pressure"] if pressure else reference["height"]

        # Get the prior distributions if requested
        if plot_prior_profiles:
            warnings.warn("Generating the prior distributions may be broken! Check outputs carefully!")
            priors = self.core.generate_prior_distributions()
            y3 = priors["pressure"] if pressure else priors["height"]

        # Plot the profiles; each gas keeps one colour across its retrieved/reference/prior lines
        color_cycle = itertools.cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
        for gas_name in gas_names:
            c = next(color_cycle)
            if plot_retrieved_profiles:
                ax.plot(self.retrieved_gases[gas_name]*scale, y, label=gas_name, c=c)
            if plot_ref_profiles:
                ax.plot(reference[gas_name]*scale, y2, label=f"{gas_name} Reference", c=c, ls="dashed")
            if plot_prior_profiles:
                ax.plot(priors[gas_name]*scale, y3, label=f"{gas_name} Prior", c=c, ls="dotted")

        # Set a limit on lowest VMR. The plotted values are already multiplied by `scale`, so the
        # 1e-12 VMR floor has to be expressed in the same scaled units before comparing
        x1, _ = ax.get_xlim()
        if x1 < 1e-12 * scale:
            ax.set_xlim(1e-12 * scale, 1 * scale)

        # Set unit label
        label = f"({unit})" if unit else ""
        ax.set_xscale("log")
        ax.set_xlabel(f"Volume Mixing Ratio {label}")
        ax.legend()

    def make_summary_plot(self, figsize=(11, 10), log=False):
        """Make a summary plot with prior and retrieved spectra, error on the spectra, aerosol and chemical profiles, and chi-squared values.

        The figure is saved to ``<core_directory>/plots/summary.png`` (``summary_log.png`` when
        ``log`` is True); the plots directory is created if needed.

        Args:
            figsize (tuple[float, float]): matplotlib figure size
            log (bool): Whether to plot the spectrum and residuals on a log scale

        Returns:
            matplotlib.Figure: The produced figure
            dict(str: matplotlib.Axes): The axes of the produced figure with labels:
                'A' for the spectrum
                'B' for the residuals
                'C' for the chi-squared plot
                'D' for the aerosol profiles
                'E' for the gas profiles"""
        gas_names = [label for label, profile in self.profiles.items()
                     if isinstance(profile, profiles.GasProfile)]
        fig, axs = plt.subplot_mosaic("AAA\nBBB\nCDE",
                                      gridspec_kw={"hspace": 0.25, "wspace": 0.35},
                                      figsize=figsize)
        axs["A"].sharex(axs["B"])
        self.plot_spectrum(ax=axs["A"], log=log)
        self.plot_spectrum_residuals(ax=axs["B"], log=log)
        self.plot_chisq(ax=axs["C"])
        self.plot_aerosol_profiles(ax=axs["D"])
        self.plot_gas_profiles(ax=axs["E"], gas_names=gas_names)
        if self.core.forward:
            fig.suptitle(f"Forward model in {self.core_directory.resolve()}", y=0.91)
        else:
            fig.suptitle(f"Retrieval in {self.core_directory.resolve()}", y=0.91)
        plots_dir = self.core_directory / "plots"
        plots_dir.mkdir(exist_ok=True)
        fig.savefig(plots_dir / f"summary{'_log' if log else ''}.png", bbox_inches="tight", dpi=400)
        return fig, axs

    def make_iterations_plot(self, figsize=(14, 10)):
        """Plot the state vector for each iteration of the retrieval. Each retrieved parameter is plotted on a separate axis.

        The figure is saved to ``<core_directory>/plots/iterations.png``; the plots directory is
        created if needed.

        Args:
            figsize (tuple[float, float]): matplotlib figure size

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes or numpy.ndarray[matplotlib.Axes]: The Axes as returned by plt.subplots

        Raises:
            TypeError: If the core is a forward model (there are no iterations to plot)"""
        if self.core.forward:
            raise TypeError("Forward models cannot have iterations plotted")
        data = self.itr.state_vectors
        n_params = len(data.columns)
        nrows = int(np.ceil(np.sqrt(n_params)))
        ncols = int(np.ceil(n_params / nrows))
        fig, axs = plt.subplots(nrows, ncols, figsize=figsize, sharex=True)
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
        flat_axes = axs.reshape(-1) if isinstance(axs, np.ndarray) else [axs]
        for ax, param in itertools.zip_longest(flat_axes, data):
            if param is None:
                # More grid cells than parameters: hide the spare Axes
                ax.axis("off")
                continue
            ax.plot(data[param])
            ax.set_ylabel(param)
        fig.supxlabel('Iteration Number')
        fig.tight_layout()
        plots_dir = self.core_directory / "plots"
        plots_dir.mkdir(exist_ok=True)
        fig.savefig(plots_dir / "iterations.png", bbox_inches="tight", dpi=400)
        return fig, axs

    def savefig(self, name, fig=None, **kwargs):
        """
        Save a matplotlib figure to a file in the core's `plots` directory.

        The target directory is created if it does not exist.

        Args:
            name (str): The name of the file to save the figure as.
            fig (matplotlib.Figure, optional): The figure to save. If None, the current figure will be saved. Default is None.
            **kwargs: Additional keyword arguments to pass to `savefig`. ``bbox_inches`` defaults
                to "tight" but may be overridden.

        Returns:
            None
        """
        path = self.core_directory / "plots" / name
        path.parent.mkdir(parents=True, exist_ok=True)
        kwargs.setdefault("bbox_inches", "tight")
        target = plt if fig is None else fig
        target.savefig(path, **kwargs)

    def delete(self, confirm=True):
        """Delete the core directory belonging to this result. This action is irreversible!

        The NemesisResult object itself is left as-is and should not be used afterwards.

        Args:
            confirm (bool): Whether to prompt for confirmation

        Returns:
            None"""
        if confirm:
            answer = input(f"Are you sure you want to delete this core ({self.core_directory})? y/n ").lower()
            if answer != "y":
                return
        shutil.rmtree(self.core_directory)
class SensitivityAnalysis:
    """Class for analysing the sensitivity cores generated by `eleos.cores.create_sensitivity_analysis`

    Attributes:
        parent_directory (pathlib.Path): The directory containing the sensitivity cores
        results (list[NemesisResult]): A list of NemesisResult objects for each core
        baseline (NemesisResult): The baseline result, an alias for results[0]
        params (pandas.DataFrame): A DataFrame detailing which parameters were varied in each core and their value
    """

    def __init__(self, parent_directory):
        """
        Args:
            parent_directory (str): The directory containing all the individual cores
        """
        self.parent_directory = Path(parent_directory)
        # NOTE(review): with failed='warn' a core that fails to load is skipped, which shifts the
        # positions that get_results looks up by "Core ID" - confirm that no cores are missing
        self.results = load_multiple_cores(parent_directory, failed='warn')
        self.baseline = self.results[0]
        self.params = pd.read_csv(self.parent_directory / "sensitivity_analysis.txt")

    def _get_all_params(self):
        """Return the (profile label, parameter) pairs listed in params, collapsing consecutive repeats."""
        pairs = zip(self.params["Profile Label"], self.params["Parameter"])
        return [pair for pair, _ in itertools.groupby(pairs)]

    def _get_params(self, profile_label, parameter):
        """Filter the params DataFrame to get the rows that match the given profile label and parameter."""
        return self.params[(self.params["Parameter"] == parameter) & (self.params["Profile Label"] == profile_label)]

    def get_results(self, profile_label, parameter):
        """Get the cores that varied the given parameter in the given profile.

        Args:
            profile_label: The label of the profile
            parameter: The parameter that was varied

        Returns:
            list[NemesisResult]: A list of NemesisResult objects that varied the given parameter in the given profile"""
        # Core IDs are 1-based positions into self.results
        return [self.results[core_id - 1] for core_id in self._get_params(profile_label, parameter)["Core ID"]]

    @plotting
    def plot_parameter(self, ax, profile_label, parameter):
        """Plot the sensitivity of the model to the given parameter in the given profile.

        Each varied core is drawn as a shaded band between its model/baseline ratio and 1, red
        when its factor is above 1 and blue otherwise.

        Args:
            ax: The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes
            profile_label: The label of the profile to plot
            parameter: The variable to plot

        Returns:
            matplotlib.Figure: The Figure object to which the Axes belong
            matplotlib.Axes: The Axes object onto which the data was plotted"""
        def format_percent(value, _pos):
            # Tick formatter: show the ratio to the baseline as a signed percentage change
            if value == 1:
                return "0"
            return f"{((value-1)*100):+.0f}%"

        factors = self._get_params(profile_label, parameter)["Factor"]
        results = self.get_results(profile_label, parameter)
        base = self.baseline.retrieved_spectrum.model
        for factor, result in zip(factors, results):
            ratio = result.retrieved_spectrum.model / base
            ax.fill_between(result.retrieved_spectrum.wavelength,
                            ratio,
                            1,
                            alpha=0.2,
                            color="#FF0000" if factor > 1 else "#0044FF")
        # Make the y-range symmetric about 1 (no change from the baseline)
        low, high = ax.get_ylim()
        bound = np.max((np.abs(low-1), np.abs(high-1)))
        ax.set_ylim(-bound+1, bound+1)
        ax.set_ylabel("Change from baseline")
        ax.set_xlabel("Wavelength (µm)")
        ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(format_percent))
        ax.axhline(1, c="k", ls="dashed")

    def make_parameters_plot(self, ncol=3):
        """Plot every varied (profile, parameter) pair on its own Axes in a grid and save the figure.

        The figure is saved to "plots/sensitivity.png".

        Args:
            ncol (int): Number of columns in the grid

        Returns:
            matplotlib.Figure: The produced figure
            numpy.ndarray[matplotlib.Axes]: The flattened array of Axes"""
        pairs = self._get_all_params()
        nrow = int(np.ceil(len(pairs) / ncol))
        # squeeze=False so a 1x1 grid still yields an array that can be flattened
        fig, axs = plt.subplots(nrow, ncol, figsize=(4*ncol, 1.5*nrow), sharex=False, squeeze=False)
        axs = axs.flatten()
        # Hide the spare Axes at the end of the grid. Slicing from len(pairs) is empty when the grid
        # is exactly full (a negative slice such as axs[-0:] would select every Axes instead)
        for ax in axs[len(pairs):]:
            ax.set_axis_off()
        for ax, name in zip(axs, pairs):
            self.plot_parameter(ax, *name)
            ax.set_ylabel("")
            ax.set_xlabel("")
            ax.text(0.99, 0.05, " ".join(name).replace("_", " "), transform=ax.transAxes, ha="right", va="bottom")
        fig.supxlabel("Wavelength (µm)")
        fig.supylabel("Radiance change from baseline", x=0.01)
        fig.tight_layout()
        # NOTE(review): this path is relative to the current working directory, not to
        # parent_directory - confirm that is intended
        fig.savefig("plots/sensitivity.png", dpi=300)
        return fig, axs

    def savefig(self, name, fig=None, **kwargs):
        """
        Save a matplotlib figure to a file in the parent_directory.

        Args:
            name (str): The name of the file to save the figure as.
            fig (matplotlib.Figure, optional): The figure to save. If None, the current figure will be saved. Default is None.
            **kwargs: Additional keyword arguments to pass to `savefig`.

        Returns:
            None
        """
        target = plt if fig is None else fig
        target.savefig(self.parent_directory / name, bbox_inches="tight", **kwargs)
class GasAnalysis:
    """Compare spectra from cores that each exclude one gas against the baseline core.

    ``<parent_directory>/gasindex.csv`` lists a core ID and an excluded gas per row; the core
    whose excluded gas is "None" is the baseline that every other spectrum is compared against.

    Attributes:
        parent_directory (pathlib.Path): The directory containing gasindex.csv and the cores
        gas_index (pandas.DataFrame): The contents of gasindex.csv
        spectra (dict[str, pandas.DataFrame]): The retrieved_spectrum of each core, keyed by excluded gas
    """

    def __init__(self, parent_directory):
        """
        Args:
            parent_directory (str): Directory containing gasindex.csv and the core_<id> directories
        """
        self.parent_directory = Path(parent_directory)
        # Only "NaN" counts as missing, so the "None" baseline label is kept as a string
        self.gas_index = pd.read_csv(self.parent_directory / "gasindex.csv", keep_default_na=False, na_values=['NaN'])
        self.spectra = dict()
        for _, (core_id, excluded_gas) in self.gas_index.iterrows():
            result = NemesisResult(self.parent_directory / f"core_{core_id}")
            self.spectra[excluded_gas] = result.retrieved_spectrum

    def _get_biggest_contributors(self):
        """For each spectral point, find the excluded gas whose model differs most from the baseline.

        Returns:
            list: One entry per point of the "None" baseline: the key of self.spectra with the
            largest absolute model difference there, or None where no spectrum differs. Ties go
            to the gas that comes first in self.spectra; NaN differences never win.
        """
        baseline = self.spectra["None"].model.to_numpy()
        gases = [gas for gas in self.spectra if gas != "None"]
        if not gases:
            return [None] * len(baseline)
        diffs = np.abs(np.vstack([self.spectra[gas].model.to_numpy() for gas in gases]) - baseline)
        diffs = np.where(np.isnan(diffs), 0.0, diffs)
        best = diffs.argmax(axis=0)  # first maximum along the gas axis
        largest = diffs.max(axis=0)
        return [gases[i] if d > 0 else None for i, d in zip(best, largest)]

    @plotting
    def plot_contributions_2D(self, ax):
        """Plot, for each excluded gas, its model spectrum minus the "None" baseline model."""
        ax.set_ylabel("Gas contribution (arb. units)")
        reference_spectrum = self.spectra["None"]
        for gas, spectrum in self.spectra.items():
            if gas == "None":
                continue
            diff = spectrum.model - reference_spectrum.model
            ax.plot(spectrum.wavelength, diff, label=gas)
        ax.legend(ncols=2)

    @plotting
    def plot_contributions_3D(self, ax, normalise=False, log_scale=False, cmap="viridis"):
        """Show each gas's difference from the baseline as one row of a wavelength-by-gas colour map.

        Args:
            ax: The matplotlib.Axes object to plot to. If omitted then create a new Figure and Axes
            normalise (bool): Divide each gas's row by its own maximum absolute value
            log_scale (bool): Use a logarithmic colour scale (lower limit at least 1e-10)
            cmap (str): Name of the matplotlib colormap
        """
        warnings.warn("not fully tested!")
        reference_spectrum = self.spectra["None"]
        diffs = []
        labels = []
        for gas, spectrum in self.spectra.items():
            if gas == "None":
                continue
            # Differences are clipped to be positive so that a log colour scale is possible
            diff = np.clip(spectrum.model - reference_spectrum.model, 1e-20, np.inf)
            if normalise:
                diff /= np.nanmax(np.abs(diff))
            diffs.append(diff)
            labels.append(gas)
        if log_scale:
            norm = mpl.colors.LogNorm(vmin=np.nanmax([np.nanmin(diffs), 1e-10]), vmax=np.nanmax(diffs))
        else:
            norm = None
        # NOTE(review): uses the wavelength grid of the last spectrum iterated; assumes all
        # spectra share the same grid
        ax.pcolormesh(spectrum.wavelength, range(len(diffs)), diffs, cmap=cmap, norm=norm)
        ax.set_yticks(range(len(diffs)), labels=labels)

    @plotting
    def plot_highlighted_spectrum(self, ax):
        """Plot the baseline model with its background coloured by the biggest contributing gas."""
        baseline = self.spectra["None"]
        biggas = self._get_biggest_contributors()
        ax.plot(baseline.wavelength, baseline.model, c='k', lw=1, marker="o")
        ax.set_ylabel("Radiance\n(μW cm$^{-2}$ sr$^{-1}$ μm$^{-1}$)")
        ylim = ax.get_ylim()
        # Map each gas to a numeric index in order of first appearance, so colours and legend
        # order are reproducible between runs (set iteration order of strings is not)
        gas_to_index = {gas: i for i, gas in enumerate(dict.fromkeys(biggas))}
        # Create gradient data - map each gas to its index
        gradient = np.array([gas_to_index[gas] for gas in biggas]).reshape(1, -1)
        gradient = np.vstack((gradient, gradient))
        # Create a colormap with unique colors for each gas (matplotlib.cm.get_cmap was removed
        # in matplotlib 3.9; pyplot.get_cmap remains)
        cmap = plt.get_cmap('tab20', len(gas_to_index))
        # Paint the background over the whole visible y-range, then restore the original limits
        ax.pcolormesh(baseline.wavelength, ylim, gradient, cmap=cmap, shading="nearest", zorder=-1)
        ax.set_ylim(ylim)
        # Create a legend
        legend_elements = [mpl.patches.Patch(facecolor=cmap(i), label=gas) for gas, i in gas_to_index.items()]
        ax.legend(handles=legend_elements, loc='upper right')
def load_multiple_cores(parent_directory, failed='warn'):
    """Read in all the cores in a given directory and return a list of NemesisResult objects.

    Cores are the ``core_*`` entries of ``parent_directory``, loaded in numerical order of the
    integer after their final underscore (core_1, core_2, ..., core_10).

    Args:
        parent_directory (str): The directory containing all the individual core directories
        failed (str): Whether to raise an error if a core fails to load ('raise'), to warn and skip it ('warn'), or to skip it silently ('skip')

    Returns:
        list[NemesisResult]: A list containing the result object for each core that loaded successfully

    Raises:
        ValueError: If ``failed`` is not one of 'raise', 'warn' or 'skip'
    """
    # Explicit check rather than assert, which is stripped when Python runs with -O
    if failed not in ("raise", "warn", "skip"):
        raise ValueError("failed parameter must be one of 'raise', 'warn' or 'skip'")
    parent_directory = Path(parent_directory)
    out = []
    for core in sorted(parent_directory.glob("core_*"), key=sort_key_paths):
        try:
            out.append(NemesisResult(core))
        except Exception as e:
            if failed == "raise":
                raise
            if failed == "warn":
                warnings.warn(f"Failed to load core {core}: {e}")
    return out
def load_best_cores(parent_directory, n, failed='warn'):
    """Load n cores from the parent_directory with the lowest chi-squared values.

    Only each core's nemesis.prc is parsed to rank the cores; full NemesisResult objects are
    built just for the selected ones.

    Args:
        parent_directory (str): The directory containing all the individual core directories
        n (int): The number of cores to load (n <= 0 returns an empty list)
        failed (str): For a core whose nemesis.prc cannot be read: raise the error ('raise'), warn and skip it ('warn'), or skip it silently ('skip')

    Returns:
        list[NemesisResult]: A list containing the result object for each core, sorted by chi-squared value (so lowest chi-squared is index 0 in the list)

    Raises:
        ValueError: If ``failed`` is not one of 'raise', 'warn' or 'skip'
    """
    if failed not in ("raise", "warn", "skip"):
        raise ValueError("failed parameter must be one of 'raise', 'warn' or 'skip'")
    parent_directory = Path(parent_directory)
    scored = []  # list of tuples: (chisq, core_directory)
    for core_directory in parent_directory.glob("core_*"):
        try:
            prc = parsers.NemesisPrc(core_directory / "nemesis.prc")
            chisq = prc.chisq[-1]
        except Exception as e:
            if failed == "raise":
                raise
            if failed == "warn":
                warnings.warn(f"Skipping {core_directory}: {e}")
            continue
        scored.append((chisq, core_directory))
    # Stable sort: cores with equal chi-squared keep the order in which they were found
    scored.sort(key=lambda item: item[0])
    return [NemesisResult(core_dir) for _, core_dir in scored[:max(n, 0)]]
def load_parallelised_cores(parent_directory):
    """Given a directory of cores generated with cores.parallelise_forward, load them and recombine into a
    single NemesisResult object. The chi-sq of the core will be set to None, and the chi_sqs attribute will
    be set to the chi squared values of each chunk (ie, chi squared as a coarse function of wavelength)

    The first core's NemesisResult is modified in place and returned.

    Args:
        parent_directory (str): Path to the directory of cores

    Returns:
        NemesisResult: The recombined results object

    Raises:
        ValueError: If no cores are found in ``parent_directory``
    """
    ress = load_multiple_cores(parent_directory, failed="raise")
    if not ress:
        raise ValueError(f"No cores found in {parent_directory}")
    combined = ress[0]
    # Concatenate every chunk in one call instead of growing the DataFrame chunk by chunk;
    # both right-hand sides are evaluated before combined's attributes are overwritten
    combined.retrieved_spectrum = pd.concat([r.retrieved_spectrum for r in ress], ignore_index=True)
    combined.chi_sqs = [r.chi_sq for r in ress]
    combined.chi_sq = None
    return combined
def sort_key_paths(path):
    """Sort key for ``core_<n>`` style paths: the integer after the final underscore.

    Args:
        path (str or pathlib.Path): Path whose string form ends in ``_<integer>``

    Returns:
        int: The trailing integer, e.g. 12 for ``.../core_12``

    Raises:
        ValueError: If the text after the last underscore is not an integer
    """
    _, _, suffix = str(path).rpartition("_")
    return int(suffix)