Source code for pjpipe.utils.utils

import copy
import functools
import gc
import inspect
import logging
import os
import warnings

import numpy as np
from astropy.convolution import convolve_fft
from astropy.io import fits
from astropy.nddata.bitmask import interpret_bit_flags, bitfield_to_boolean_mask
from astropy.stats import sigma_clipped_stats, SigmaClip
from astropy.table import Table
from astropy.wcs import WCS
from photutils.segmentation import detect_threshold, detect_sources
from reproject import reproject_interp
from reproject.mosaicking.subset_array import ReprojectedArraySubset
from scipy.interpolate import RegularGridInterpolator
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels.dqflags import pixel
from stdatamodels import util

from .. import __version__

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib

# Useful values

# Header keywords tried, in this order, by get_pixscale when looking for an
# image pixel scale (the first one present with a numeric value wins)
PIXEL_SCALE_NAMES = ["XPIXSIZE", "CDELT1", "CD1_1", "PIXELSCL"]

# Pixel scales, in arcsec per pixel: parse_parameter_dict multiplies
# "<value>pix" parameters by these to convert them to arcsec. NIRCam is split
# into its short- and long-wavelength channels (see get_band_type for how a
# band is assigned to a channel)
jwst_pixel_scales = {
    "miri": 0.11,
    "nircam_long": 0.063,
    "nircam_short": 0.031,
}

# All NIRCAM bands. get_band_type assigns bands whose wavelength code
# (characters 1-3 of the name, e.g. "200" in "F200W") is <= 212 to the
# short-wavelength channel and the rest to the long-wavelength channel
nircam_bands = [
    "F070W",
    "F090W",
    "F115W",
    "F140M",
    "F150W",
    "F162M",
    "F164N",
    "F150W2",
    "F182M",
    "F187N",
    "F200W",
    "F210M",
    "F212N",
    "F250M",
    "F277W",
    "F300M",
    "F322W2",
    "F323N",
    "F335M",
    "F356W",
    "F360M",
    "F405N",
    "F410M",
    "F430M",
    "F444W",
    "F460M",
    "F466N",
    "F470N",
    "F480M",
]

# All MIRI bands (membership here is what makes get_band_type return "miri")
miri_bands = [
    "F560W",
    "F770W",
    "F1000W",
    "F1130W",
    "F1280W",
    "F1500W",
    "F1800W",
    "F2100W",
    "F2550W",
]

# FWHM of bands in pixels, keyed by band name.
# NOTE(review): presumably in each band's native pixel scale (see
# jwst_pixel_scales). The drop from F212N (2.341) to F250M (1.340) is where
# NIRCam switches from the short to the long channel — confirm.
fwhms_pix = {
    # NIRCAM
    "F070W": 0.987,
    "F090W": 1.103,
    "F115W": 1.298,
    "F140M": 1.553,
    "F150W": 1.628,
    "F162M": 1.770,
    "F164N": 1.801,
    "F150W2": 1.494,
    "F182M": 1.990,
    "F187N": 2.060,
    "F200W": 2.141,
    "F210M": 2.304,
    "F212N": 2.341,
    "F250M": 1.340,
    "F277W": 1.444,
    "F300M": 1.585,
    "F322W2": 1.547,
    "F323N": 1.711,
    "F335M": 1.760,
    "F356W": 1.830,
    "F360M": 1.901,
    "F405N": 2.165,
    "F410M": 2.179,
    "F430M": 2.300,
    "F444W": 2.302,
    "F460M": 2.459,
    "F466N": 2.507,
    "F470N": 2.535,
    "F480M": 2.574,
    # MIRI
    "F560W": 1.636,
    "F770W": 2.187,
    "F1000W": 2.888,
    "F1130W": 3.318,
    "F1280W": 3.713,
    "F1500W": 4.354,
    "F1800W": 5.224,
    "F2100W": 5.989,
    "F2550W": 7.312,
}

# Filename extension for each instrument (e.g. the MIRI imager detector name
# "mirimage"), returned by get_band_ext. "nrc*" is presumably a glob pattern
# covering all NIRCam detectors — confirm against callers.
band_exts = {
    "nircam": "nrc*",
    "miri": "mirimage",
}

# Log through the JWST pipeline's "stpipe" logger. The NullHandler keeps
# logging quiet when no other handler has been configured.
log = logging.getLogger("stpipe")
log.addHandler(logging.NullHandler())


def get_pixscale(hdu):
    """Get the pixel scale of an HDU from its header.

    The keywords in ``PIXEL_SCALE_NAMES`` are tried in order:

    * A keyword missing from the header is skipped.
    * A keyword whose value cannot be turned into a float (ValueError) is
      also skipped.
    * For ``CDELT1``/``CD1_1``, the scale is taken from the full header WCS
      and multiplied by 3600 (degrees to arcsec).
    * For the other keywords, the absolute header value is returned as-is.

    Args:
        hdu: HDU to get pixel scale for

    Returns:
        float: Pixel scale

    Raises:
        Warning: If no usable pixel scale keyword is found
    """
    header = hdu.header
    wcs_keywords = ("CDELT1", "CD1_1")

    for keyword in PIXEL_SCALE_NAMES:
        try:
            try:
                scale = np.abs(float(header[keyword]))
            except ValueError:
                # Present but not numeric: try the next keyword
                continue

            if keyword in wcs_keywords:
                # Use the projection-plane scale from the full WCS rather than
                # the raw keyword value
                scale = WCS(header).proj_plane_pixel_scales()[0].value * 3600

            return scale
        except KeyError:
            # Keyword not in the header: try the next one
            continue

    raise Warning("No pixel scale found")
def load_toml(filename):
    """Read a .toml file into a dictionary.

    Args:
        filename (str): Path to toml file

    Returns:
        dict: Parsed contents of the file
    """
    # tomllib requires the file to be opened in binary mode
    with open(filename, "rb") as toml_file:
        return tomllib.load(toml_file)
def get_band_type(
    band,
    short_long_nircam=False,
):
    """Get the instrument type from the band name

    Args:
        band (str): Name of band, e.g. "F200W" or "F770W"
        short_long_nircam (bool): Whether to distinguish between short/long
            NIRCam bands. Defaults to False

    Returns:
        If ``short_long_nircam`` is False, the instrument, "miri" or "nircam".
        Otherwise a tuple ``(band_type, short_long)``:

        * For NIRCam, ``short_long`` is "nircam_short" when the wavelength
          code (characters 1-3 of the band name) is <= 212, else
          "nircam_long".
        * For MIRI, ``short_long`` is "miri".

    Raises:
        ValueError: If the band is in neither ``miri_bands`` nor
            ``nircam_bands``
    """
    if band in miri_bands:
        band_type = "miri"
    elif band in nircam_bands:
        band_type = "nircam"
    else:
        raise ValueError(f"band {band} unknown")

    if not short_long_nircam:
        return band_type

    # MIRI has no channel split; the "channel" is just the instrument
    if band_type != "nircam":
        return band_type, band_type

    # e.g. "F200W" -> 200. The cutoff puts F212N and bluer bands in the short
    # channel
    channel = "nircam_short" if int(band[1:4]) <= 212 else "nircam_long"
    return band_type, channel
def get_band_ext(band):
    """Get the specific extension (e.g. mirimage) for a band

    Args:
        band (str): Name of band

    Returns:
        str: Entry of ``band_exts`` for the band's instrument

    Raises:
        ValueError: If the band is not a known NIRCam or MIRI band
    """
    instrument = get_band_type(band)
    return band_exts[instrument]
def get_default_args(func):
    """Pull the default arguments from a function.

    Args:
        func: Callable to inspect

    Returns:
        dict: Maps each parameter that has a default value (positional or
        keyword-only) to that default, in signature order
    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults
def get_kws(
    parameters,
    func,
    band,
    target,
    max_level=None,
):
    """Set up kwarg dict for a function, looping over band and target

    Only arguments of ``func`` that have a default are considered. For each
    one, the value in ``parameters`` (resolved for the band and target by
    ``parse_parameter_dict``) wins. The function's default is used instead
    when the argument is absent from ``parameters`` or resolves to
    "VAL_NOT_FOUND".

    Args:
        parameters: Dictionary of parameters
        func: Function to set the parameters for
        band: Band to pull band-specific parameters for
        target: Target to pull target-specific parameters for
        max_level: How far to recurse down the dictionary. Defaults to None,
            which will recurse all the way down

    Returns:
        dict: Keyword arguments for ``func``
    """
    func_kws = {}
    for arg_name, default in get_default_args(func).items():
        if arg_name not in parameters:
            func_kws[arg_name] = default
            continue

        parsed = parse_parameter_dict(
            parameters=parameters,
            key=arg_name,
            band=band,
            target=target,
            max_level=max_level,
        )
        func_kws[arg_name] = default if parsed == "VAL_NOT_FOUND" else parsed

    return func_kws
def parse_parameter_dict(
    parameters,
    key,
    band,
    target,
    max_level=None,
):
    """Pull values out of a parameter dictionary

    Nested dictionaries are descended one level at a time. At each level the
    first matching key is taken, in this priority order:

    1. the target
    2. the band
    3. the NIRCam channel ("nircam_short"/"nircam_long"), NIRCam bands only
    4. the instrument ("nircam"/"miri")

    A string containing "pix" that parses as a number of pixels (e.g. "3pix")
    is converted to arcsec using the band's pixel scale. Other strings
    containing "pix" are returned unchanged.

    Args:
        parameters (dict): Dictionary of parameters and associated values
        key (str): Particular key in parameter_dict to consider
        band (str): JWST band, to parse out band type and potentially per-band
            values
        target (str): JWST target, for very specific values
        max_level: Maximum level to recurse down. Defaults to None, which will
            go until it finds something that's not a dictionary

    Returns:
        The resolved value. It is the string "VAL_NOT_FOUND" if a nested
        dictionary had no matching key. It may still be a dictionary if
        ``max_level`` was reached first.

    Raises:
        KeyError: If ``key`` is not in ``parameters``
        ValueError: If ``band`` is not a known band
    """
    if max_level is None:
        max_level = np.inf

    value = parameters[key]

    band_type, short_long = get_band_type(
        band,
        short_long_nircam=True,
    )
    pixel_scale = jwst_pixel_scales[short_long]

    level = 0
    while level < max_level and isinstance(value, dict):
        # Define a priority here. It goes:
        # * target
        # * band
        # * nircam_short/nircam_long
        # * nircam/miri
        if target in value:
            value = value[target]
        elif band in value:
            value = value[band]
        elif band_type == "nircam" and short_long in value:
            value = value[short_long]
        elif band_type in value:
            value = value[band_type]
        else:
            value = "VAL_NOT_FOUND"
        level += 1

    # Finally, if we have a string with a 'pix' in there, convert it to
    # arcsec. A string that merely contains "pix" without being a pixel
    # quantity (e.g. a path or a mode name) is returned unchanged instead of
    # raising ValueError.
    if isinstance(value, str) and "pix" in value:
        try:
            value = float(value.strip("pix")) * pixel_scale
        except ValueError:
            pass

    return value
def attribute_setter(
    pipeobj,
    parameters,
    band,
    target,
):
    """Set attributes for a function

    How each top-level key is handled:

    * If its value is a plain ``dict``, the key names a sub-object: each of
      its keys is resolved and set as the attribute ``key.subkey``.
    * Otherwise the key is resolved and set directly on ``pipeobj``.

    Values that resolve to "VAL_NOT_FOUND" are skipped.

    Args:
        pipeobj: Function/class to set parameters for
        parameters: Dictionary of parameters to set
        band: Band to pull band-specific parameters for
        target: Target to pull target-specific parameters for

    Returns:
        The same ``pipeobj``, modified in place
    """
    for key, entry in parameters.items():
        if type(entry) is not dict:
            value = parse_parameter_dict(
                parameters=parameters,
                key=key,
                band=band,
                target=target,
            )
            if not value == "VAL_NOT_FOUND":
                recursive_setattr(pipeobj, key, value)
            continue

        # Nested dict: ``key`` names a sub-object, each subkey an attribute on it
        for subkey in entry:
            value = parse_parameter_dict(
                parameters=entry,
                key=subkey,
                band=band,
                target=target,
            )
            if not value == "VAL_NOT_FOUND":
                recursive_setattr(pipeobj, ".".join([key, subkey]), value)

    return pipeobj
def recursive_setattr(
    f,
    attribute,
    value,
    protected=False,
):
    """Set a potentially nested (dotted) attribute.

    This is needed for the JWST pipeline steps, which have levels to them. For
    example, "step.param" sets ``param`` on ``f.step``.

    Args:
        f: Object to set the attribute on
        attribute: Attribute name, optionally dotted
        value: Value to set
        protected: If True, prefix the final attribute name with an
            underscore (so "step.param" sets ``f.step._param``). Defaults to
            False

    Returns:
        None (the result of ``setattr``)
    """
    parent_path, _, name = attribute.rpartition(".")
    if protected:
        name = f"_{name}"
    # Only walk down the dotted path when there is one
    owner = recursive_getattr(f, parent_path) if parent_path else f
    return setattr(owner, name, value)
def recursive_getattr(
    f,
    attribute,
    *args,
):
    """Get a potentially nested (dotted) attribute.

    This is needed for the JWST pipeline steps, which have levels to them. For
    example, "step.param" returns ``f.step.param``.

    Args:
        f: Object to start from
        attribute: Attribute name, optionally dotted
        args: Optional default, passed to ``getattr`` at every level. A
            missing attribute at any level therefore yields the default, and
            the remaining lookups are made on that default.

    Returns:
        The attribute value
    """
    current = f
    for name in attribute.split("."):
        current = getattr(current, name, *args)
    return current
def get_obs_table(
    files,
    check_bgr=False,
    check_type="parallel_off",
    background_name="off",
):
    """Pull necessary info out of fits headers

    Builds one table row per file using ``parse_fits_to_table``; see that
    function for ``check_bgr``, ``check_type`` and ``background_name``.

    Args:
        files: Iterable of file paths
        check_bgr (bool): Whether to classify files as science or background.
            Defaults to False
        check_type (str): Background classification method. Defaults to
            'parallel_off'
        background_name (str): Name indicating a background observation.
            Defaults to 'off'

    Returns:
        astropy.table.Table: One row per file
    """
    columns = (
        ("File", str),
        ("Type", str),
        ("Obs_ID", str),
        ("Filter", str),
        ("Start", str),
        ("Exptime", float),
        ("Objname", str),
        ("Program", str),
        ("Array", str),
    )
    tab = Table(
        names=[col_name for col_name, _ in columns],
        dtype=[col_dtype for _, col_dtype in columns],
    )

    for file_name in files:
        row = parse_fits_to_table(
            file_name,
            check_bgr=check_bgr,
            check_type=check_type,
            background_name=background_name,
        )
        tab.add_row(row)

    return tab
def parse_fits_to_table(
    file,
    check_bgr=False,
    check_type="parallel_off",
    background_name="off",
):
    """Pull necessary info out of fits headers

    Args:
        file (str): File to get info for
        check_bgr (bool): Whether to check if this is a science or background
            observation (in the MIRI case)
        check_type (str): How to check if background observation. Options:

            * 'parallel_off': use the filename to see if it's a parallel
              observation with NIRCAM.
            * 'check_in_name': use the observation name to check, matching
              against 'background_name'.

            Defaults to 'parallel_off'
        background_name (str): Name to indicate background observation.
            Defaults to 'off'.

    Returns:
        tuple: Nine fields, in order:

        * the file
        * 'sci' or 'bgr'
        * observation number
        * filter
        * observation start date
        * exposure duration
        * lowercased observation label
        * program number
        * lowercased, stripped subarray name

    Raises:
        Warning: If ``check_bgr`` is True and ``check_type`` is not known
    """
    # Validate up front so an unknown check_type is reported before any I/O
    if check_bgr and check_type not in ("parallel_off", "check_in_name"):
        raise Warning(f"check_type {check_type} not known")

    # Figure out if we're a background observation or not
    f_type = "sci"
    if check_bgr and check_type == "parallel_off":
        # A "2" as the third character of the second underscore-separated
        # filename field marks a parallel (background) exposure
        file_split = os.path.basename(file)
        if file_split.split("_")[1][2] == "2":
            f_type = "bgr"

    # Open the file once and pull out everything we need
    with datamodels.open(file) as im:
        if check_bgr and check_type == "check_in_name":
            if background_name in im.meta.target.proposer_name.lower():
                f_type = "bgr"

        obs_n = im.meta.observation.observation_number
        obs_filter = im.meta.instrument.filter
        obs_date = im.meta.observation.date_beg
        obs_duration = im.meta.exposure.duration
        obs_label = im.meta.observation.observation_label.lower()
        obs_program = im.meta.observation.program_number
        array_name = im.meta.subarray.name.lower().strip()

    return (
        file,
        f_type,
        obs_n,
        obs_filter,
        obs_date,
        obs_duration,
        obs_label,
        obs_program,
        array_name,
    )
def get_dq_bit_mask(
    dq,
    bit_flags="~DO_NOT_USE+NON_SCIENCE",
):
    """Get a DQ bit mask from an input image

    Args:
        dq: DQ array
        bit_flags: Bit flags to get mask for, using JWST pixel flag names.
            The default flags pixels with DO_NOT_USE or NON_SCIENCE set, so
            the mask only keeps science pixels.

    Returns:
        np.ndarray: uint8 array of the same shape as ``dq``. Pixels that pass
        are 0 (``good_mask_value=0``); flagged pixels are 1.
    """
    dq_bits = interpret_bit_flags(bit_flags=bit_flags, flag_name_map=pixel)
    # JWST DQ flags span 32 bits (NON_SCIENCE alone is 512). Casting to uint8
    # would silently drop every flag above bit 7, so keep the full width.
    dq_bit_mask = bitfield_to_boolean_mask(
        np.asarray(dq).astype(np.uint32),
        dq_bits,
        good_mask_value=0,
        dtype=np.uint8,
    )
    return dq_bit_mask
def make_source_mask(
    data,
    mask=None,
    nsigma=3,
    npixels=3,
    dilate_size=11,
    sigclip_iters=5,
):
    """Make a source mask from segmentation image

    Sources are detected above a sigma-clipped threshold. Their segmentation
    footprint is then dilated into a boolean mask.

    Args:
        data: 2D image
        mask: Optional mask of pixels to ignore when computing the detection
            threshold. Defaults to None
        nsigma: Detection threshold, and sigma-clipping level, in standard
            deviations. Defaults to 3
        npixels: Minimum number of connected pixels for a source. Defaults to 3
        dilate_size: Size of the dilation applied to the segmentation image.
            Defaults to 11
        sigclip_iters: Maximum sigma-clipping iterations. Defaults to 5

    Returns:
        np.ndarray: Boolean mask, True on (dilated) sources. All False if no
        sources are detected.
    """
    sc = SigmaClip(
        sigma=nsigma,
        maxiters=sigclip_iters,
    )
    threshold = detect_threshold(
        data,
        mask=mask,
        nsigma=nsigma,
        sigma_clip=sc,
    )
    segment_map = detect_sources(
        data,
        threshold,
        npixels=npixels,
    )

    # detect_sources returns None when nothing is detected. Check for that
    # explicitly rather than catching AttributeError, which would also hide
    # real errors (e.g. a photutils version without make_source_mask).
    if segment_map is None:
        return np.zeros(data.shape, dtype=bool)

    return segment_map.make_source_mask(size=dilate_size)
def sigma_clip(
    data,
    dq_mask=None,
    sigma=1.5,
    n_pixels=5,
    max_iterations=20,
):
    """Get sigma-clipped statistics for data

    Sources are masked first via ``make_source_mask``, combined with
    ``dq_mask`` if one is given, before the clipped statistics are computed.

    Args:
        data: Image to compute statistics for
        dq_mask: Optional mask of bad pixels. Defaults to None
        sigma: Clipping (and source detection) level. Defaults to 1.5
        n_pixels: Minimum connected pixels for a source. Defaults to 5
        max_iterations: Maximum clipping iterations. Defaults to 20

    Returns:
        tuple: (mean, median, std_dev)
    """
    # Source detection can be noisy with warnings; silence them here only
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        source_mask = make_source_mask(
            data, mask=dq_mask, nsigma=sigma, npixels=n_pixels
        )

    combined_mask = (
        source_mask if dq_mask is None else np.logical_or(source_mask, dq_mask)
    )

    stats = sigma_clipped_stats(
        data, mask=combined_mask, sigma=sigma, maxiters=max_iterations
    )
    mean, median, std_dev = stats
    return mean, median, std_dev
def reproject_image(
    file,
    optimal_wcs,
    optimal_shape,
    hdu_type="data",
    do_sigma_clip=False,
    stacked_image=False,
    do_level_data=False,
):
    """Reproject an image to an optimal WCS

    Only the part of the optimal grid that the input image covers is
    reprojected. It is returned as a ``ReprojectedArraySubset`` that records
    where that part sits (imin:imax, jmin:jmax) in the full optimal grid.

    Args:
        file: File to reproject
        optimal_wcs: Optimal WCS for input image stack
        optimal_shape: Optimal shape for input image stack
        hdu_type: Type of HDU. Can either be 'data' or 'var_rnoise'
        do_sigma_clip: Whether to perform sigma-clipping or not. Defaults to False
        stacked_image: Stacked image or not? Defaults to False
        do_level_data: Whether to level between amplifiers or not. Defaults to False

    Returns:
        A ``ReprojectedArraySubset`` holding the reprojected data and a
        footprint. The footprint is 0 where the reprojected value is zero or
        non-finite. Returns None if the image does not overlap the optimal
        grid.

    Raises:
        Warning: If ``hdu_type`` is not recognised (unstacked images only)
    """
    if not stacked_image:
        # Unstacked images are read as JWST datamodels: take the DQ bad-pixel
        # mask and a FITS (SIP) version of the WCS
        with datamodels.open(file) as hdu:
            dq_bit_mask = get_dq_bit_mask(hdu.dq)
            wcs = hdu.meta.wcs.to_fits_sip()
            w_in = WCS(wcs)

            # Level data (but not in subarray mode)
            if "sub" not in hdu.meta.subarray.name.lower() and do_level_data and hdu_type == "data":
                hdu.data = level_data(hdu)

            if hdu_type == "data":
                data = copy.deepcopy(hdu.data)
            elif hdu_type == "var_rnoise":
                data = copy.deepcopy(hdu.var_rnoise)
            else:
                raise Warning(f"Unsure how to deal with hdu_type {hdu_type}")
    else:
        # Stacked images are plain FITS files: use the SCI extension and its
        # header as the WCS. There is no DQ mask in this case.
        with fits.open(file) as hdu:
            data = copy.deepcopy(hdu["SCI"].data)
            wcs = hdu["SCI"].header
            w_in = WCS(wcs)
            dq_bit_mask = None

    sig_mask = None
    if do_sigma_clip:
        # Dilated source mask, as ints so it can be reprojected and compared
        # against 1 below
        sig_mask = make_source_mask(
            data,
            mask=dq_bit_mask,
            dilate_size=7,
        )
        sig_mask = sig_mask.astype(int)

    # Treat exact zeros as missing data.
    # NOTE(review): the collapsed source did not show whether this line sits
    # inside the `if do_sigma_clip` block; placed at function level — confirm
    data[data == 0] = np.nan

    # Find the minimal shape for the reprojection. This is from the astropy reproject routines
    ny, nx = data.shape
    # Outer pixel edges of the four image corners
    xc = np.array([-0.5, nx - 0.5, nx - 0.5, -0.5])
    yc = np.array([-0.5, -0.5, ny - 0.5, ny - 0.5])
    xc_out, yc_out = optimal_wcs.world_to_pixel(w_in.pixel_to_world(xc, yc))

    if np.any(np.isnan(xc_out)) or np.any(np.isnan(yc_out)):
        # A corner did not map onto the optimal grid: use the whole grid
        imin = 0
        imax = optimal_shape[1]
        jmin = 0
        jmax = optimal_shape[0]
    else:
        imin = max(0, int(np.floor(xc_out.min() + 0.5)))
        imax = min(optimal_shape[1], int(np.ceil(xc_out.max() + 0.5)))
        jmin = max(0, int(np.floor(yc_out.min() + 0.5)))
        jmax = min(optimal_shape[0], int(np.ceil(yc_out.max() + 0.5)))

    # No overlap with the optimal grid
    if imax < imin or jmax < jmin:
        return

    wcs_out_indiv = optimal_wcs[jmin:jmax, imin:imax]
    shape_out_indiv = (jmax - jmin, imax - imin)

    data_reproj_small = reproject_interp(
        (data, wcs),
        output_projection=wcs_out_indiv,
        shape_out=shape_out_indiv,
        return_footprint=False,
    )

    # Mask out bad DQ, but only for unstacked images. Masks are reprojected
    # with nearest-neighbour so they stay 0/1.
    if not stacked_image:
        dq_reproj_small = reproject_interp(
            (dq_bit_mask, wcs),
            output_projection=wcs_out_indiv,
            shape_out=shape_out_indiv,
            return_footprint=False,
            order="nearest-neighbor",
        )
        data_reproj_small[dq_reproj_small == 1] = np.nan

    # NOTE(review): assumed to be at function level (not nested under
    # `if not stacked_image`), since sig_mask is also built for stacked
    # images — confirm
    if do_sigma_clip:
        sig_mask_reproj_small = reproject_interp(
            (sig_mask, wcs),
            output_projection=wcs_out_indiv,
            shape_out=shape_out_indiv,
            return_footprint=False,
            order="nearest-neighbor",
        )
        data_reproj_small[sig_mask_reproj_small == 1] = np.nan

    # Footprint is 1 on usable pixels, 0 on zero or non-finite ones
    footprint = np.ones_like(data_reproj_small)
    footprint[
        np.logical_or(data_reproj_small == 0, ~np.isfinite(data_reproj_small))
    ] = 0

    data_array = ReprojectedArraySubset(
        data_reproj_small, footprint, imin, imax, jmin, jmax
    )

    # Drop the file object reference and force a collection, presumably to
    # keep memory down when many images are reprojected in turn
    del hdu
    gc.collect()

    return data_array
def _regrid_kernel(kernel_data, kernel_pix_scale, image_pix_scale):
    """Interpolate a square kernel onto an image's pixel scale, normalized to 1.

    Args:
        kernel_data: 2D kernel array (assumed square, centred on its middle
            pixel)
        kernel_pix_scale: Pixel scale of the kernel
        image_pix_scale: Pixel scale of the image, in the same units

    Returns:
        np.ndarray: Odd-sized kernel sampled at ``image_pix_scale``, summing
        to 1
    """
    kernel_length = kernel_data.shape[0]
    original_central_pixel = (kernel_length - 1) / 2
    original_grid = (
        np.arange(kernel_length) - original_central_pixel
    ) * kernel_pix_scale

    # Calculate kernel size after interpolating to the image pixel scale.
    # Because sometimes there's a little pixel scale rounding error, subtract
    # a little bit off the optimum size (Tom Williams).
    interpolate_kernel_size = (
        np.floor(kernel_length * kernel_pix_scale / image_pix_scale) - 2
    )
    # Ensure the kernel has a central pixel
    if interpolate_kernel_size % 2 == 0:
        interpolate_kernel_size -= 1

    # Define a new coordinate grid onto which to project the kernel, using
    # the pixel scale of the image
    new_central_pixel = (interpolate_kernel_size - 1) / 2
    new_grid = (
        np.arange(interpolate_kernel_size) - new_central_pixel
    ) * image_pix_scale

    # Use matrix ("ij") indexing so the first coordinate array runs along
    # axis 0 (rows) and the second along axis 1 (columns). That is the axis
    # order RegularGridInterpolator uses for kernel_data. The default "xy"
    # indexing would sample the transposed kernel.
    row_coords, col_coords = np.meshgrid(new_grid, new_grid, indexing="ij")

    grid_interpolated = RegularGridInterpolator(
        (original_grid, original_grid),
        kernel_data,
        bounds_error=False,
        fill_value=0.0,
    )
    kernel_interp = grid_interpolated((row_coords.ravel(), col_coords.ravel()))
    kernel_interp = kernel_interp.reshape(row_coords.shape)

    # Ensure the interpolated kernel is normalized to 1
    return kernel_interp / np.nansum(kernel_interp)


def do_jwst_convolution(
    file_in,
    file_out,
    file_kernel,
    blank_zeros=True,
    output_grid=None,
):
    """
    Convolves input image with an input kernel, and writes to disk.

    Will also process errors and do reprojection, if specified.

    The kernel is first regridded to the image pixel scale and normalized to
    1. SCI is convolved with the kernel. ERR is propagated as
    sqrt(ERR**2 convolved with kernel**2), which ignores covariance between
    pixels.

    Args:
        file_in: Path to image file (with SCI and ERR extensions)
        file_out: Path to output file
        file_kernel: Path to kernel for convolution
        blank_zeros: If True, then all zero values will be set to NaNs.
            Defaults to True
        output_grid: None (no reprojection to be done) or tuple (wcs, shape)
            defining the grid for reprojection. Defaults to None
    """
    with fits.open(file_kernel) as kernel_hdu:
        kernel_pix_scale = get_pixscale(kernel_hdu[0])
        # Copy so the kernel does not depend on the file staying open
        kernel_data = np.array(kernel_hdu[0].data)

    with fits.open(file_in) as image_hdu:
        if blank_zeros:
            # make sure that all zero values were set to NaNs, which
            # astropy convolution handles with interpolation
            zero_mask = image_hdu["SCI"].data == 0
            image_hdu["ERR"].data[zero_mask] = np.nan
            image_hdu["SCI"].data[zero_mask] = np.nan

        image_pix_scale = get_pixscale(image_hdu["SCI"])
        kernel_interp = _regrid_kernel(kernel_data, kernel_pix_scale, image_pix_scale)

        # Now with the kernel centered and matched in pixel scale to the
        # input image use the FFT convolution routine from astropy to
        # convolve.
        conv_im = convolve_fft(
            image_hdu["SCI"].data,
            kernel_interp,
            allow_huge=True,
            preserve_nan=True,
            fill_value=np.nan,
        )

        # Convolve errors (with kernel**2, do not normalize it).
        # This, however, doesn't account for covariance between pixels
        conv_err = np.sqrt(
            convolve_fft(
                image_hdu["ERR"].data ** 2,
                kernel_interp ** 2,
                preserve_nan=True,
                allow_huge=True,
                normalize_kernel=False,
            )
        )

        image_hdu["SCI"].data = conv_im
        image_hdu["ERR"].data = conv_err

        if output_grid is None:
            image_hdu.writeto(file_out, overwrite=True)
            return

        # Reprojection to target wcs grid defined in output_grid
        target_wcs, target_shape = output_grid
        target_wcs_header = target_wcs.to_header()

        # The SCI header is the input WCS for BOTH reprojections, so it must
        # not be modified before both have run. Output headers are copies.
        input_header = image_hdu["SCI"].header

        repr_data, fp = reproject_interp(
            (conv_im, input_header),
            output_projection=target_wcs,
            shape_out=target_shape,
        )
        fp = fp.astype(bool)
        repr_data[~fp] = np.nan

        # Note - this ignores the errors of interpolation and thus the
        # resulting errors might be underestimated
        repr_err = reproject_interp(
            (conv_err, input_header),
            output_projection=target_wcs,
            shape_out=target_shape,
            return_footprint=False,
        )
        repr_err[~fp] = np.nan

        # Both output extensions now live on the target grid, so both get
        # the target WCS
        sci_header = image_hdu["SCI"].header.copy()
        sci_header.update(target_wcs_header)
        err_header = image_hdu["ERR"].header.copy()
        err_header.update(target_wcs_header)

        hdulist_out = fits.HDUList([fits.PrimaryHDU(header=image_hdu[0].header)])
        hdulist_out.append(fits.ImageHDU(data=repr_data, header=sci_header, name="SCI"))
        hdulist_out.append(fits.ImageHDU(data=repr_err, header=err_header, name="ERR"))
        hdulist_out.writeto(file_out, overwrite=True)
def level_data(
    im,
):
    """Level overlaps in NIRCAM amplifiers

    The image columns are split into four equal-width strips (amplifiers).
    The strips are compared pairwise, left to right, so offsets accumulate.
    For each pair of neighbouring strips:

    1. Take the row-wise median of the last 20 columns of the left strip.
    2. Take the row-wise median of the first 20 columns of the right strip.
    3. Compute the sigma-clipped median of the row differences.
    4. Add that value to the right strip.

    Bad pixels are excluded from the medians: those flagged by the DQ mask,
    non-finite data or error values, and data values of exactly zero.

    Args:
        im: Input datamodel (its ``data``, ``err`` and ``dq`` are used)

    Returns:
        np.ndarray: Leveled copy of ``im.data``; ``im`` is not modified
    """
    data = copy.deepcopy(im.data)
    quadrant_size = data.shape[1] // 4
    # Columns either side of each amplifier boundary used to compare levels
    n_overlap = 20

    # get_dq_bit_mask returns uint8. Convert to bool before combining: a
    # uint8 array used as an index is treated as integer (fancy) indices,
    # not as a mask.
    dq_mask = (
        get_dq_bit_mask(dq=im.dq).astype(bool)
        | ~np.isfinite(im.data)
        | ~np.isfinite(im.err)
        | (im.data == 0)
    )

    for i in range(3):
        left_amp = slice(i * quadrant_size, (i + 1) * quadrant_size)
        right_amp = slice((i + 1) * quadrant_size, (i + 2) * quadrant_size)

        # Edge columns of the two neighbouring amplifiers. Copy them, so
        # blanking bad pixels below does not write NaNs into the output.
        quad_1 = data[:, left_amp][:, quadrant_size - n_overlap:].copy()
        dq_1 = dq_mask[:, left_amp][:, quadrant_size - n_overlap:]
        quad_2 = data[:, right_amp][:, :n_overlap].copy()
        dq_2 = dq_mask[:, right_amp][:, :n_overlap]

        quad_1[dq_1] = np.nan
        quad_2[dq_2] = np.nan

        with warnings.catch_warnings():
            # Fully-masked rows give NaN medians (and RuntimeWarnings); the
            # NaNs are ignored by sigma_clipped_stats
            warnings.simplefilter("ignore")
            med_1 = np.nanmedian(
                quad_1,
                axis=1,
            )
            med_2 = np.nanmedian(
                quad_2,
                axis=1,
            )
            diff = med_1 - med_2
            delta = sigma_clipped_stats(diff, maxiters=None)[1]

        # With no usable overlap pixels the offset is NaN. Adding it would
        # blank the whole amplifier, so leave this one unleveled instead.
        if not np.isfinite(delta):
            continue

        data[:, right_amp] += delta

    return data
def save_file(
    im,
    out_name,
    dr_version,
):
    """Save out an image, adding in useful metadata

    The pjpipe version and data processing version are recorded twice: as
    ``meta.pjpipe_version`` / ``meta.pjpipe_dr_version``, and as FITS
    history entries.

    Args:
        im: Input JWST datamodel
        out_name: File to save output to
        dr_version: Data processing version

    Returns:
        bool: Always True
    """
    # Save versions in the metadata...
    im.meta.pjpipe_version = __version__
    im.meta.pjpipe_dr_version = dr_version

    # ...and in the fits history
    history_lines = (
        f"PJPIPE VER: {__version__}",
        f"DATA PROCESSING VER: {dr_version}",
    )
    for line in history_lines:
        im.history.append(util.create_history_entry(line))

    im.save(out_name)
    return True