import numpy as np
import pandas as pd
import glob
import h3ppy
import planetmapper
import copy
import warnings
from astropy.io import fits
from collections import defaultdict
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy.signal import fftconvolve
from pprint import pprint
from . import utils
from . import spx
from . import constants
[docs]
def margin_trim(cube, margin_size=3):
"""
Set the outer rings of spaxels in a cube to NaN.
Args:
cube (np.ndarray): The spectral cube to trim with shape (wavelengths, x, y).
margin_size (int): The size of the margin to trim from each edge (eg. 2 will set every spaxel within 2 spaxels of the edge to NaN).
Returns:
np.ndarray: The trimmed spectral cube with NaN values in the margins.
"""
if margin_size == 0:
return cube
cube[:, :margin_size] = np.nan
cube[:, :, :margin_size] = np.nan
cube[:, -margin_size:] = np.nan
cube[:, :, -margin_size:] = np.nan
return cube
[docs]
def downsample_spectrum(wavelengths, spectrum, errors=None, num_points=100, special_regions=None):
"""Downsample a spectrum to a given number of spectral points.
Args:
wavelengths (np.ndarray): The corresponding wavelengths
spectrum (np.ndarray): The spectral data
error (np.ndarray): The error on the spectral data (optional)
num_points (int): The number of points to reduce the spectrum to
special_regions (dict([int, int]: int)):
A dict of regions of the form
{(start, end): num_points} where the spectrum should be downsampled
to a specific number of points. Use -1 for full-resolution sampling
Returns:
wavelengths (np.ndarray): The trimmed wavelengths
spectrum (np.ndarray): The trimmed spectrum
errors (np.ndarray): The trimmed errors (if provided)"""
wavelengths = np.asarray(wavelengths)
spectrum = np.asarray(spectrum)
if errors is not None:
errors = np.asarray(errors)
# Keep track of indices to include
selected_indices = set(np.linspace(0, len(wavelengths) - 1, num_points, dtype=int))
# Handle special regions
if special_regions:
for (start, end), region_points in special_regions.items():
region_mask = (wavelengths >= start) & (wavelengths <= end)
region_indices = np.where(region_mask)[0]
# Keep all points in this region
if region_points == -1:
selected_indices.update(region_indices)
# Or downsample appropriately
else:
if region_points > len(region_indices):
region_points = len(region_indices)
downsampled = np.linspace(0, len(region_indices) - 1, region_points, dtype=int)
selected_indices.update(region_indices[downsampled])
# Final selection
final_indices = sorted(selected_indices)
if errors is not None:
return wavelengths[final_indices], spectrum[final_indices], errors[final_indices]
return wavelengths[final_indices], spectrum[final_indices]
[docs]
def resample_to_new(wavelength_source, spectrum_source, wavelength_target, kind="linear", fill_value="extrapolate"):
"""Resample a spectrum onto a new wavelength grid via interpolation.
Args:
wavelength_source (array-like): Original wavelength values of the spectrum (must be sorted).
spectrum_source (array-like): Spectrum values corresponding to `wavelength_source`.
wavelength_target (array-like): Target wavelength grid to resample onto.
kind (str, optional): Interpolation type. Options include
"linear", "nearest", "cubic", etc. Defaults to "linear".
fill_value (str or float, optional): Value to use outside the range of
`wavelength_source`. If "extrapolate", allows extrapolation. Defaults to "extrapolate".
Returns:
np.ndarray: Spectrum values resampled onto the `wavelength_target` grid.
"""
wavelength_source = np.asarray(wavelength_source)
spectrum_source = np.asarray(spectrum_source)
wavelength_target = np.asarray(wavelength_target)
if len(wavelength_source) != len(spectrum_source):
raise ValueError("wavelength_source and spectrum_source must have the same length.")
interp_func = interp1d(wavelength_source, spectrum_source, kind=kind,
fill_value=fill_value, bounds_error=False)
return interp_func(wavelength_target)
[docs]
def wavelength_select(wavelengths, spectrum, errors=None, min_wl=None, max_wl=None, epsilon=0):
"""Select a range of wavelengths from a spectrum.
Args:
wavelengths (np.ndarray): The wavelength grid
spectrum (np.ndarray): The spectral data
errors (np.ndarray): The error on the spectral data (optional)
min_wl (float): The minimum wavelength to select
max_wl (float): The maximum wavelength to select
epsilon (float): A small fudge factor to the end of the mask to fix FPE when grouping spectra
Returns:
wavelengths (np.ndarray): The trimmed wavelengths
spectrum (np.ndarray): The trimmed spectrum
errors (np.ndarray): The trimmed errors (if provided)"""
mask = np.ones_like(wavelengths, dtype=bool)
if min_wl is not None:
mask &= wavelengths >= min_wl
if max_wl is not None:
mask &= wavelengths <= max_wl + epsilon
if errors is not None:
return wavelengths[mask], spectrum[mask], errors[mask]
return wavelengths[mask], spectrum[mask]
[docs]
def interpolate_nans(wavelengths, spectrum):
mask = ~np.isnan(spectrum)
return np.interp(wavelengths, wavelengths[mask], spectrum[mask])
[docs]
def subtract_non_lte(wavelengths, spectrum):
"""Utility function for combining subtract_h3p and subtract_ch4.
For more fine-grained control, use those function independently.
Args:
wavelengths (np.ndarray): Wavelength grid (microns)
spectrum (np.ndarray): Spectral data (in units of W/cm2/sr/um)
Returns:
np.ndarray: The wavelength grid (unchanged)
np.ndarray: The spectrum with H3+ and CH4 removed"""
h3p_subtracted = subtract_h3p(wavelengths, spectrum)
ch4_subtracted = subtract_ch4(wavelengths, h3p_subtracted)
return wavelengths, ch4_subtracted
[docs]
def subtract_h3p(wavelengths, spectrum, latitude=-60, region=(3.525, 3.55), return_model=False, **h3p_kwargs):
"""Subtract H3+ emission from a spectrum
Args:
wavelengths (np.ndarray): Wavelength grid
spectrum (np.ndarray): Measured spectrum in units of W/cm2/sr/um
latitude (float): Optionally specify a latitude to make a more informed guess at initial values
region (float, float): The region to use to fit the H3+ model (by default this is a triple of bright lines)
return_model (bool): If True, then return the h3p object used for the fit
Returns:
np.ndarray: Spectrum with H3+ subtracted in units of W/cm2/sr/um
h3ppy.h3p: If return_model, the h3p object that was used to fit
"""
def remap(x, in_min, in_max, out_min, out_max):
return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min
temp = remap(np.abs(latitude), 40, 90, 500, 1000)
h3p_kwargs.setdefault("temperature", temp)
h3p_kwargs.setdefault("R", 2700)
spectrum = copy.deepcopy(spectrum) * 10000 # Convert to W/m2/sr/um
nanmask = np.isnan(spectrum)
spectrum = interpolate_nans(wavelengths, spectrum)
w, s = wavelength_select(wavelengths, spectrum, min_wl=region[0], max_wl=region[1])
h3p = h3ppy.h3p()
h3p.set(**h3p_kwargs, wave=w, data=s)
h3p.guess_density()
fit = h3p.fit(verbose=False)
h3p.set(wave=wavelengths)
h3p_pred = h3p.model(background_0=0)
subtracted = (spectrum - h3p_pred) / 10000 # Convert back to W/cm2/sr/um
subtracted[nanmask] = np.nan # Restore NaNs where they were in the original spectrum
subtracted[subtracted < 0] = np.nan # Remove any negative values
if return_model:
return subtracted, h3p
else:
return subtracted
[docs]
def subtract_ch4(wavelengths, spectrum, nlines=25, linewidth=0.008, A0=1e-6, lineshape="voigt", return_model=False):
"""Subtract the brightest N individual CH4 lines from a spectrum.
Args:
wavelengths (np.ndarray): The wavelength grid
spectrum (np.ndarray): The spectral data in units of W/cm2/sr/um
nlines (int): The number of peaks to fit
linewidth (float): The intial guess for the line FWHM in microns
A0 (float): The initial guess for the amplitude of the lines
lineshape (str): The shape of the lines to use, either 'gaussian', 'lorentzian', or 'voigt'
return_model (bool): If True, then return the model spectrum as well as the final spectrum
Returns:
np.ndarray: The spectrum with CH4 subtracted in units of W/cm2/sr/um
"""
raise NotImplementedError("CH4 subtraction is dodgy at best. Catch this error if you are sure you want to try")
def multifunc(x, func, p0s):
total = np.zeros_like(x)
for p in p0s:
total += func(x, *p)
return total
def fit(wavelengths, spectrum, line, linewidth, c="magenta"):
if not np.nanmin(wavelengths) < line < np.nanmax(wavelengths):
print("fail")
return
w, s = wavelength_select(wavelengths,
spectrum,
min_wl=line - linewidth/1.5,
max_wl=line + linewidth/1.5)
offset_0 = np.nanmedian(np.nanpercentile(s, 25))
plt.hlines(y=offset_0/1e8, xmin=line-linewidth/2, xmax=line+linewidth/2, color=c)
try:
popt, pcov = curve_fit(eval("utils."+lineshape),
w, s,
p0= [A0, line, linewidth/2, offset_0],
bounds=([0, line - linewidth/2, linewidth*0.2, offset_0*0.9],
[np.inf, line + linewidth/2, linewidth*2, offset_0*1.1]),)
plt.plot(w, eval("utils."+lineshape)(w, *popt)/1e8, color=c)
return popt
except (RuntimeError, ValueError) as e:
print(f"[eleos] Failed to fit CH4 line at {line:.3f} um: {e}")
nanmask = np.isnan(spectrum)
spectrum = interpolate_nans(wavelengths, spectrum) # linearly interpolate nans
spectrum *= 1e8 # numerical stability
lines = []
i = 0
centre_l = 3.3
centre_r = 3.35
# Extract the N most intense lines from the CH4 spectrum that have a minimum spacing (prevent picking out double lines)
for _, line in constants.CH4_LINES.iterrows():
wavelength = line["wavelength"]
if all(abs(wavelength - existing) > linewidth for existing in lines) and not centre_l < wavelength < centre_r:
lines.append(wavelength)
i += 1
if i >= nlines:
break
# Fit each of those lines
fits = []
for line in lines:
popt = fit(wavelengths, spectrum, line, linewidth)
if popt is None:
continue
fits.append(popt)
# Now fit the extra lines that are too weak to be captured in the line list
# probably not a good idea!
# if extra_lines:
# extralines = [3.478, 3.491, 3.504, 3.517, 3.166, 3.157, 3.149, 3.142, 3.133, 3.124]
# for line in extralines:
# popt = fit(wavelengths, spectrum, line, linewidth, c="green")
# if popt is None:
# continue
# fits.append(popt)
# Set the offset to 0 for subtraction
for i in range(len(fits)):
fits[i][3] = 0
# Add all the lines together to make the final model
model = multifunc(wavelengths, eval("utils."+lineshape), fits)
subtracted = spectrum - model
# Restore original NaNs to the spectrum
subtracted[nanmask] = np.nan
# Remove the massive central peak completely
subtracted[(wavelengths > centre_l) & (wavelengths < centre_r)] = np.nan
model[(wavelengths > centre_l) & (wavelengths < centre_r)] = np.nan
subtracted /= 1e8
model /= 1e8
if return_model:
return subtracted, model
else:
return subtracted
[docs]
def zonal_average_tiles(filepaths, lat_width, error_scale=1, rmsd_threshold=10, filters=None):
"""Get the zonal averages from a set of JWST navigated cubes.
Args:
filepaths (List(str)): List of filepaths to tiles to use in the averaging
lat_width (float): The width of the latitude bins. The given latitude is in the centre of the bin.
error_scale (float): Multiply the average error by this amount (unused so far)
rmsd_threshold (int): Reject any spaxels that are in the top x percentile for root-mean-square deviation from the median
filters (dict): Apply any filters to remove spaxels before averaging. Format is key=planetmapper backplane name (str)
and value=(min, max). For example, filters={'EMISSION':(0, 70)} will reject any spaxels with emission angle greater than 70
Returns:
np.ndarray: Wavelength grid, temporary - will be added to df later
pandas.DataFrame: DataFrame containing the zonal means and information for .spx files.
Columns are spectrum, error (np.ndarrays), phase, emission, azimuth, lon (floats), num_spaxels (int).
Use df.index to get latitudes"""
if filters is None:
filters = dict()
def mask_and_reshape(arr, mask, shape=(-1,)):
return (arr * mask).reshape(*shape)
def compute_rmsd(group):
# Stack the spectra for this group
stacked = np.stack(group['spectrum'].values)
# Compute the median spectrum
median = np.nanmedian(stacked, axis=0)
# Compute RMSD for each row
# This results in an array of shape (n_rows,)
rmsd = np.sqrt(np.nanmean((stacked - median)**2, axis=1))
# Return the group with an additional column
group = group.copy()
group['rmsd'] = rmsd
return group
def nanmean_column(group, column):
return np.nanmean(np.stack(group[column].values), axis=0)
# Dicts to hold individual spaxels, final zonal spectra, various angles
spaxels = pd.DataFrame(columns=["spectrum", "error", "lat", "phase", "emission", "azimuth", "lon", "nans"])
a = 0
# Iterate over each filepath to bin the spaxels by latitude
for tile in filepaths:
print("Processing ", tile)
# Get planetmapper Observation object
obs = planetmapper.Observation(tile)
obs = add_error_cubes(obs)
# Create the mask from filters
mask = np.ones_like(obs.data[0]).astype(bool)
for bp_name, bounds in filters.items():
backplane = obs.get_backplane_img(bp_name)
single_mask = (bounds[0] < backplane) & (backplane < bounds[1])
mask &= single_mask
lonimg = obs.get_lon_img()
emiimg = obs.get_emission_angle_img()
aziimg = obs.get_azimuth_angle_img()
phaimg = obs.get_phase_angle_img()
# Bin each spaxel into latitude bins
lat_bp = obs.get_lat_img()
for l in np.arange(-90, 90, lat_width):
lat_mask = (l - lat_width/2 < lat_bp) & (l + lat_width/2 > lat_bp)
final_mask = lat_mask * mask
if not (final_mask).any():
continue
#print(l)
data = mask_and_reshape(obs.data, final_mask, shape=(obs.data.shape[0], -1))
err = mask_and_reshape(obs.error, final_mask, shape=(obs.data.shape[0], -1))
lons = mask_and_reshape(lonimg, final_mask)
emis = mask_and_reshape(emiimg, final_mask)
azim = mask_and_reshape(aziimg, final_mask)
phas = mask_and_reshape(phaimg, final_mask)
# Remove any all-zero or all-NaN spaxels
for i in range(data.shape[1]):
if not ((data[:,i] < 1e-10) | (np.isnan(data[:,i]))).all():
spaxels.loc[a] = [data[:,i].copy(), err[:, i].copy(), l, phas[i], emis[i], azim[i], lons[i], np.sum(np.isnan(data[:,i]))]
a += 1
# Remove any spaxels with a really large number of NaNs (basically the very edges of the cube)
spaxels.drop(spaxels[spaxels.nans > 200].index, inplace=True)
# Remove any spaxels that are a certain threshold away from the root-mean-square deviation from the median
spaxels = spaxels.groupby('lat', group_keys=False).apply(compute_rmsd)
spaxels.drop(spaxels[(spaxels.rmsd > np.percentile(spaxels.rmsd, 100-rmsd_threshold))].index, inplace=True)
# Calculate final zonal spectra
z = spaxels.groupby('lat')
zonal = z.mean()
zonal['spectrum'] = z.apply(nanmean_column, "spectrum")
zonal['error'] = z.apply(nanmean_column, "error")
zonal["num_spaxels"] = z.apply(len)
del zonal["nans"], zonal["rmsd"]
return obs.get_wavelengths_from_header(), zonal
[docs]
def groupby(filepaths, by="EMISSION", binsize=2, binstart=0, binend=90, rmsd_threshold=10, filters=None):
"""Group spaxels from a set of JWST navigated cubes by a given backplane value.
Args:
filepaths (List(str)): List of filepaths to tiles to use
by (str): The planetmapper backplane name to group by
binsize (float): The size of the bins to group by
binstart (float): The start of the first bin (if None then use min of backplane)
binend (float): The end of the last bin (if None then use max of backplane)
rmsd_threshold (int): Reject any spaxels that are in the top x percentile for root-mean-square deviation from the median
filters (dict): Apply any filters to remove spaxels before averaging. Format is key=planetmapper backplane name (str)
and value=(min, max). For example, filters={'EMISSION':(0, 70)} will accept only spaxels with emission angles between 0 and 70
Returns:
pd.DataFrame: Indexed with bin midpoints, columns are wavelength, spectrum, error, num_spaxels, **average value of backplanes
"""
def average_dict(data):
result = {}
for key, lists in data.items():
# Flatten all nested lists into one
flat = [x for sub in lists for x in sub]
result[key] = sum(flat) / len(flat)
return result
if filters is None:
filters = dict()
# Convert the list of tiles into a single dataframe of spaxels
spaxels = flatten_tiles(filepaths)
spaxels["spectrum"] = spaxels["spectrum"].apply(lambda x: np.where(np.array(x) < 0, np.nan, x))
# Check valid backplane name
if by not in spaxels.columns:
raise ValueError(f"Backplane {by} not found in spaxels. Available backplanes are: {list(spaxels.columns)}")
# Set default bin edges
if binstart is None:
binstart = np.nanmin(spaxels[by])
if binend is None:
binend = np.nanmax(spaxels[by])
if binstart > binend:
binend, binstart = binstart, binend
# Apply any filters to remove spaxels before averaging
for filter_bp, bounds in filters.items():
if filter_bp not in spaxels.columns:
raise ValueError(f"Backplane {filter_bp} not found in spaxels. Available backplanes are: {list(spaxels.columns)}")
spaxels = spaxels[(spaxels[filter_bp] > bounds[0]) & (spaxels[filter_bp] < bounds[1])]
# Bin the spaxels by the given backplane
cut = pd.cut(spaxels[by], np.arange(binstart, binend, binsize))
grouped = spaxels.groupby(cut)
nonzero = 0
for name, df in grouped:
nonzero += len(df) > 0
# Create dataframe to store final output
out = pd.DataFrame()
# Iterate over each bin
for name, df in grouped:
if len(df) == 0:
continue
# Add column to allow grouping by wavelength setting
df["_wl_key"] = df["wavelength"].apply(lambda x: x[0])
filtergrouped = df.groupby("_wl_key")
# Lists to store the individual spectra to combine multiple gratings
wls = []
spec = []
errs = []
extras = defaultdict(list)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
# Iterate over each unique wavelength grid (ie. each filter/grating combination)
n = 0
for wl, ddf in filtergrouped:
print(f"Processing {name} group starting at {wl:2f} um with {len(ddf)} spaxels")
n += len(ddf)
# Calculate group median
spectra = np.stack(ddf["spectrum"].values)
median = np.nanmedian(spectra, axis=0)
# Check for all-nan spectrum and skip if so
if np.all(np.isnan(median)):
print("Skipping all-nan bin")
continue
# Caclulate rmsd and remove outliers
ddf["rmsd"] = np.sqrt(np.nanmean((spectra - median)**2, axis=1))
filtered_ddf = ddf.drop(ddf[(ddf.rmsd > np.percentile(ddf.rmsd, 100-rmsd_threshold))].index)
print(f"Rejected {len(ddf) - len(filtered_ddf)} outliers out of {len(ddf)} spaxels")
# Add to list of spectra to combine multiple gratings
wls.append(ddf.iloc[0].wavelength)
spec.append(np.nanmean(np.stack(filtered_ddf["spectrum"].values), axis=0))
errs.append(np.nanstd(np.stack(filtered_ddf["error"].values), axis=0))
for col in [x for x in df.columns if x not in ["wavelength", "spectrum", "error", "_wl_key"]]:
extras[col].append(list(ddf[col]))
# Combine the multiple filter/grating combinations
w,s,e = combine_multiple_spectra(*zip(wls, spec, errs))
extra = average_dict(extras)
out = pd.concat([out, pd.DataFrame([{"wavelength":w, "spectrum":s, "error":e, "num_spaxels":n} | extra], index=[name.mid])])
return out
[docs]
def flatten_tiles(filepaths):
"""Flatten a set of navigated cubes into a single dataframe of spaxels"""
# Dataframe to hold individual spaxels, final zonal spectra, various angles
spaxels = pd.DataFrame()
# Iterate over each tile in the mosaic
for tile in filepaths:
print("Processing ", tile)
pd.set_option('display.max_columns', 6)
pd.set_option('display.width', 120)
# Get planetmapper Observation object
obs = planetmapper.Observation(tile)
obs = add_error_cubes(obs)
# Rearrange spaxels into list such that:
# obs.data[:,a,b] --> pixels[b + obs.data.shape[2]*a]
df = pd.DataFrame()
wavelength = obs.get_wavelengths_from_header()
spectra = np.moveaxis(obs.data, 0, -1).reshape(-1, obs.data.shape[0])
errors = np.moveaxis(obs.error, 0, -1).reshape(-1, obs.data.shape[0])
df["wavelength"] = [wavelength] * spectra.shape[0]
df["spectrum"] = spectra.tolist()
df["error"] = errors.tolist()
for backplane in obs.backplanes.keys():
df[backplane] = obs.get_backplane_img(backplane).flatten()
spaxels = pd.concat([spaxels, df], ignore_index=True)
return spaxels
[docs]
def multiple_cube_average(cubes):
return np.nanmean(np.array(cubes), axis=(0,2,3))
[docs]
def get_observations(pattern, add_errors=True):
"""Get a list of planetmapper.Observation objects from a glob-style file pattern
Args:
pattern (str): The pattern to use to search for .fits files
add_errors (bool): Whether to add the error cubes as an attribute (Observation.error)
Returns:
List[planetmapper.Observation]: The found observations"""
fps = sorted(glob.glob(pattern))
obs = [planetmapper.Observation(fp) for fp in fps]
if add_error_cubes:
obs = add_error_cubes(obs)
return obs
[docs]
def add_error_cubes(observations):
"""Add the error cube from a JWST observation to a list (or a single) planetmapper.Observation object,
accessible using the new Observation.error attribute.
Args:
observations: planetmapper.Observation or List[planetmapper.Observation]: The observation(s) to add errors for
Returns:
observations: planetmapper.Observation or List[planetmapper.Observation]: The observation(s) with added errors
"""
if isinstance(observations, planetmapper.Observation):
observations = [observations,]
return_single = True
else:
return_single = False
for obs in observations:
fp = obs.path
hdul = fits.open(fp)
obs.error = hdul['ERR'].data
if return_single:
return observations[0]
else:
return observations
[docs]
def get_single_spectra(file_pattern,
out_filename=None,
max_emission=99,
margins=0,
pct_error=0,
num_points=None,
min_wl=-999,
max_wl=999):
"""Get a single spectrum from multiple observations, with options to restrict emission angle, downsample, and restrict wavelength range.
Saves the output to a .spx file.
Args:
file_pattern (str): File pattern to match observation files.
out_filename (str): Output filename template for saving the processed spectrum (if None then don't save).
This can contain any of the parameters passed into the function surrounded
by curly braces. eg. "spectra_n{num_points}.spx" will become "spectra_n80.spx"
if num_points=80 is passed in.
max_emission (float): Maximum allowable emission angle.
margins (int): Number of pixels around the edge of the image to remove.
pct_error (float): Flat percentage error to apply to the spectrum.
num_points (int): Number of points to downsample the spectrum to.
min_wl (float): Minimum wavelength in the spectrum.
max_wl (float): Maximum wavelength in the spectrum.
Returns:
np.ndarray: The wavelengths in the new .spx file
np.ndarray: The radiances in the new spx file (returns MJy/sr, writes W/cm2/sr/um to the .spx file)
np.ndarray: The radiance error in the new spx file (same units as radiance)
"""
params = locals().copy()
observations = get_observations(file_pattern)
wavelengths = observations[0].get_wavelengths_from_header()
cubes = []
weights = []
other = defaultdict(list)
for obs in observations:
mask = np.full(obs.data.shape, np.nan)
mask[:, obs.get_emission_angle_img() < max_emission] = 1
mask = margin_trim(mask, margin_size=margins)
if np.all(np.isnan(mask)):
continue
cubes.append(obs.data * mask)
weights.append(np.sum(np.where(~np.isnan(mask), 1, 0)))
other["lat"].append(np.nanmean(obs.get_lat_img() * mask))
other["lon"].append(np.nanmean(obs.get_lon_img() * mask))
other["phase"].append(np.nanmean(obs.get_phase_angle_img() * mask))
other["emission"].append(np.nanmean(obs.get_emission_angle_img() * mask))
other["azimuth"].append(np.nanmean(obs.get_azimuth_angle_img() * mask))
w = wavelengths
s = multiple_cube_average(cubes)
s, w = wavelength_select(w, s, min_wl=min_wl, max_wl=max_wl)
if num_points is not None:
s, w = downsample_spectrum(s, w, num_points=num_points)
for name, values in other.items():
other[name] = utils.nanaverage(values, weights)
err = s * pct_error
if out_filename is not None:
spx.write(out_filename.format(**params),
spectrum=s,
error=err,
wavelengths=w,
fwhm=0,
**other)
return w, s, err
[docs]
def combine_multiple_spectra(*spectra_units):
"""
Combine multiple spectra with errors. Preserves original wavelength resolution in
non-overlapping regions and uses weighted average (inverse variance) in overlaps.
Args:
*spectra_units: Tuples of (wavelength, spectrum, error), where each element is a 1D array.
Returns:
wavelengths (np.array): Stitched wavelength grid
spectrum (np.array): Combined spectral values
error (np.array): Combined errors
"""
segments = []
spectra = list(spectra_units)
spectra.sort(key=lambda t: t[0][0]) # Sort by wavelength start
for i, (wl_i, f_i, e_i) in enumerate(spectra):
wmin_i, wmax_i = wl_i.min(), wl_i.max()
overlapping_segments = []
for j, (wl_j, f_j, e_j) in enumerate(spectra):
if i == j:
continue
wmin_j, wmax_j = wl_j.min(), wl_j.max()
if wmax_j > wmin_i and wmin_j < wmax_i:
w_overlap_min = max(wmin_i, wmin_j)
w_overlap_max = min(wmax_i, wmax_j)
# Choose higher-res grid in overlap
res_i = np.median(np.diff(wl_i[(wl_i >= w_overlap_min) & (wl_i <= w_overlap_max)]))
res_j = np.median(np.diff(wl_j[(wl_j >= w_overlap_min) & (wl_j <= w_overlap_max)]))
grid = wl_i if res_i <= res_j else wl_j
wl_common = grid[(grid >= w_overlap_min) & (grid <= w_overlap_max)]
# Interpolate flux and error
fi = interp1d(wl_i, f_i, bounds_error=False, fill_value=np.nan)(wl_common)
fj = interp1d(wl_j, f_j, bounds_error=False, fill_value=np.nan)(wl_common)
ei = interp1d(wl_i, e_i, bounds_error=False, fill_value=np.nan)(wl_common)
ej = interp1d(wl_j, e_j, bounds_error=False, fill_value=np.nan)(wl_common)
# Inverse-variance weighted average
with np.errstate(divide='ignore', invalid='ignore'):
wi = 1 / (ei**2)
wj = 1 / (ej**2)
weights_sum = wi + wj
flux_comb = np.nansum([wi * fi, wj * fj], axis=0) / weights_sum
error_comb = np.sqrt(1 / weights_sum)
overlapping_segments.append((wl_common, flux_comb, error_comb))
# Remove overlap from i-th spectrum
mask = np.full(wl_i.shape, True)
for wl_o, _, _ in overlapping_segments:
mask &= ~((wl_i >= wl_o.min()) & (wl_i <= wl_o.max()))
# Add unique portion
if np.any(mask):
segments.append((wl_i[mask], f_i[mask], e_i[mask]))
# Add overlapping portions
segments.extend(overlapping_segments)
# Stitch all segments together
all_wl = np.concatenate([seg[0] for seg in segments])
all_flux = np.concatenate([seg[1] for seg in segments])
all_err = np.concatenate([seg[2] for seg in segments])
idx = np.argsort(all_wl)
return all_wl[idx], all_flux[idx], all_err[idx]
[docs]
def convolve_gaussian(wavelength, data, R, lambda0):
"""Convolve a spectrum with a Gaussian, assuming constant R. In reality resolution varies as a function of wavelength and
depends on the filters used (for JWST/NIRSpec see https://jwst-docs.stsci.edu/jwst-near-infrared-spectrograph/nirspec-instrumentation/nirspec-dispersers-and-filters)
Args:
wavelength (np.ndarray): Wavelength array (same length as data).
data (np.ndarray): The data to be convolved.
R (float): Resolving power of the instrument
lambda0 (float): The wavelength at which R is defined. For G395H this is ~4um. For G235H this is ~2.4um.
Returns:
np.ndarray: The convolved data.
"""
wavelength = np.asarray(wavelength)
# FWHM and sigma in wavelength units
fwhm = lambda0 / R
sigma = fwhm / 2.35482
# Build Gaussian kernel in wavelength space
# Kernel extends to 4 sigma for good truncation
dw = np.median(np.diff(wavelength)) # wavelength step
half_width = int(np.ceil(4 * sigma / dw))
x = np.arange(-half_width, half_width + 1) * dw
kernel = np.exp(-(x**2) / (2 * sigma**2))
kernel /= kernel.sum() # normalize so area = 1
# Convolve using FFT for efficiency
convolved = fftconvolve(data, kernel, mode='same')
return convolved
[docs]
def quickview(wavelength, spectrum, errors=None, log=True, block=True):
"""Quickly view a spectrum using matplotlib.
Args:
wavelength (np.ndarray): The wavelength grid
spectrum (np.ndarray): The spectral data
errors (np.ndarray): The error on the spectral data (optional)
log (bool): Whether to use a logarithmic scale for the y-axis
block (bool): Whether to block the execution until the plot is closed (default True
"""
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 5))
plt.plot(wavelength, spectrum, label='Spectrum')
if errors is not None:
plt.fill_between(wavelength, spectrum - errors, spectrum + errors, alpha=0.3, label='Error')
if log:
plt.yscale('log')
plt.xlabel('Wavelength (microns)')
plt.ylabel('Radiance')
plt.legend()
plt.grid()
plt.show(block=block)