Source code for pjpipe.lv3.lv3_step

import copy
import gc
import glob
import json
import logging
import os
import shutil
import time

import asdf
import numpy as np
import jwst
from jwst.datamodels import ModelContainer
from jwst.resample import ResampleStep
from stcal.outlier_detection.utils import gwcs_blot
from stcal.resample.utils import compute_mean_pixel_area

# FIXME For newer JWST versions, import ModelLibrary
try:
    from jwst.datamodels import ModelLibrary
except ImportError:
    class ModelLibrary:
        pass

from jwst.pipeline import calwebb_image3
from jwst.skymatch import SkyMatchStep
from jwst.tweakreg import TweakRegStep
from stdatamodels.jwst import datamodels

from ..utils import (
    get_band_type,
    get_band_ext,
    get_obs_table,
    parse_parameter_dict,
    attribute_setter,
    recursive_setattr,
    fwhms_pix,
    save_file,
    get_short_band_name,
)

log = logging.getLogger(__name__)

BGR_CHECK_TYPES = [
    "parallel_off",
    "check_in_name",
    "filename",
]


[docs] class Lv3Step: def __init__( self, target, band, in_dir, out_dir, dr_version, is_bgr, step_ext, procs, tweakreg_degroup_nircam_modules=False, tweakreg_degroup_nircam_short_chips=False, tweakreg_group_dithers=None, tweakreg_degroup_dithers=None, skymatch_group_dithers=None, skymatch_degroup_dithers=None, bgr_check_type="parallel_off", bgr_background_name="off", process_bgr_like_science=False, jwst_parameters=None, do_drizzle=False, do_blot=False, blot_fillval=np.nan, overwrite=False, ): """Wrapper around the level 3 JWST pipeline Args: target: Target to consider band: Band to consider in_dir: Input directory out_dir: Output directory dr_version: Data processing version is_bgr: Whether we're processing background observations or not step_ext: .fits extension for the files going into the step procs: Number of processes to run in parallel tweakreg_degroup_nircam_modules: Whether to degroup NIRCam A and B modules. Currently, the WCS is inconsistent between the two, so should probably be set to True if you see "ghosting" in the final mosaic. Defaults to False tweakreg_degroup_nircam_short_chips: Whether to degroup NIRCam short 1/2/3/4 chips. There may be some shifts between these, so should ideally find a shift for each chip. Defaults to False tweakreg_group_dithers: List of 'miri', 'nircam_long', 'nircam_short' of whether to group up dithers for tweakreg. Defaults to None, which will keep at default tweakreg_degroup_dithers: List of 'miri', 'nircam_long', 'nircam_short' of whether to degroup dithers for tweakreg. Defaults to None, which will keep at default. skymatch_group_dithers: List of 'miri', 'nircam_long', 'nircam_short' of whether to group up dithers for skymatch. Defaults to None, which will keep at default skymatch_degroup_dithers: List of 'miri', 'nircam_long', 'nircam_short' of whether to degroup dithers for skymatch. Defaults to None, which will keep at default. bgr_check_type: Method to check if obs is science or background. Options are given by BGR_CHECK_TYPES. Defaults to 'parallel_off' bgr_background_name: If `bgr_check_type` is 'check_in_name' or 'filename', this is the string to match process_bgr_like_science: If True, will process background images as if they are science images. Defaults to False jwst_parameters: Parameter dictionary to pass to the level 2 pipeline. Defaults to None, which will run the observatory defaults do_drizzle: If True, drizzle individual frames to the i2d mosaic WCS after the main pipeline run. Note: creates a lot of files. Defaults to False. do_blot: If True, blot the final i2d mosaic to the detector frame of each exposure, producing one ``*_i2d_blot.fits`` file per exposure. Independent of do_drizzle. Defaults to False. blot_fillval: Fill value for pixels outside the blotted footprint. Defaults to np.nan overwrite: Whether to overwrite or not. Defaults to False """ if jwst_parameters is None: jwst_parameters = {} if tweakreg_group_dithers is None: tweakreg_group_dithers = [] if tweakreg_degroup_dithers is None: tweakreg_degroup_dithers = [] if skymatch_group_dithers is None: skymatch_group_dithers = [] if skymatch_degroup_dithers is None: skymatch_degroup_dithers = [] if is_bgr: bgr_ext = "_bgr" else: bgr_ext = "" self.target = target self.band = band self.in_dir = in_dir self.out_dir = out_dir self.dr_version = dr_version self.is_bgr = is_bgr self.bgr_ext = bgr_ext self.step_ext = step_ext self.procs = procs self.tweakreg_degroup_nircam_modules = tweakreg_degroup_nircam_modules self.tweakreg_degroup_nircam_short_chips = tweakreg_degroup_nircam_short_chips self.tweakreg_group_dithers = tweakreg_group_dithers self.tweakreg_degroup_dithers = tweakreg_degroup_dithers self.skymatch_group_dithers = skymatch_group_dithers self.skymatch_degroup_dithers = skymatch_degroup_dithers self.bgr_check_type = bgr_check_type self.bgr_background_name = bgr_background_name self.process_bgr_like_science = process_bgr_like_science self.jwst_parameters = jwst_parameters self.do_drizzle = do_drizzle self.do_blot = do_blot self.blot_fillval = float(blot_fillval) self.overwrite = overwrite self.band_type = get_band_type(self.band) self.band_ext = get_band_ext(self.band)
[docs] def do_step(self): """Run the level 3 pipeline""" if self.overwrite: shutil.rmtree(self.out_dir) os.system(f"rm -rf {os.path.join(self.in_dir, '*.json')}") if not os.path.exists(self.out_dir): os.makedirs(self.out_dir) # Check if we've already run the step step_complete_file = os.path.join( self.out_dir, "lv3_step_complete.txt", ) if os.path.exists(step_complete_file): log.info("Step already run") return True files = glob.glob( os.path.join( self.in_dir, f"*_{self.step_ext}.fits", ) ) files.sort() asn_file = self.create_asn_file( files=files, ) success = self.run_step( asn_file, ) # If not everything has succeeded, then return a warning if not success: log.warning("Failures detected in level 3 pipeline") return False # Propagate PJPipe metadata through the output files, this is the # lv3 mosaic and the crf files all_files = [] files = glob.glob(os.path.join(self.out_dir, "*_i2d.fits", ) ) all_files.extend(files) files = glob.glob(os.path.join(self.out_dir, "*_crf.fits", ) ) all_files.extend(files) self.propagate_metadata(all_files) with open(step_complete_file, "w+") as f: f.close() return True
[docs] def create_asn_file( self, files, ): """Setup asn lv3 file""" log.info("Building asn file") check_bgr = True band_short = get_short_band_name(self.band) # If we have NIRCam/NIRISS operating with parallel offs, switch off background checking # else everything will be flagged as backgrounds if self.band_type in ["nircam", "niriss"] and self.bgr_check_type == "parallel_off": check_bgr = False asn_lv3_filename = os.path.join( self.in_dir, f"asn_lv3_{self.band}.json", ) files.sort() tab = get_obs_table( files=files, check_bgr=check_bgr, check_type=self.bgr_check_type, background_name=self.bgr_background_name, ) json_content = { "asn_type": "None", "asn_rule": "DMS_Level3_Base", "version_id": time.strftime("%Y%m%dt%H%M%S"), "code_version": jwst.__version__, "degraded_status": "No known degraded exposures in association.", "program": tab["Program"][0], "constraints": "No constraints", "asn_id": f"o{tab['Obs_ID'][0]}", "asn_pool": "none", "products": [ { "name": f"{self.target.lower()}_{self.band_type}_lv3_{band_short.lower()}{self.bgr_ext}", "members": [], } ], } # If we're only processing background, flip the switch if self.is_bgr: tab["Type"] = "sci" # If we're processing background like science, flip the switch if self.process_bgr_like_science: tab["Type"] = "sci" # Only take things flagged as science tab = tab[tab["Type"] == "sci"] for row in tab: json_content["products"][-1]["members"].append( {"expname": row["File"], "exptype": "science", "exposerr": "null"} ) with open(asn_lv3_filename, "w") as f: json.dump(json_content, f) return asn_lv3_filename
[docs] def run_step( self, asn_file, ): """Run the level 3 step Args: asn_file: Path to association JSON file """ log.info("Running level 3 pipeline") band_type, short_long = get_band_type(self.band, short_long_nircam=True) band_short = get_short_band_name(self.band) # FWHM should be set per-band for both tweakreg and source catalogue fwhm_pix = fwhms_pix[band_short] # Set up to run lv3 pipeline config = calwebb_image3.Image3Pipeline.get_config_from_reference(asn_file) im3 = calwebb_image3.Image3Pipeline.from_config_section(config) im3.output_dir = self.out_dir im3.tweakreg.kernel_fwhm = fwhm_pix # * 2 im3.source_catalog.kernel_fwhm = fwhm_pix # * 2 im3 = attribute_setter( im3, parameters=self.jwst_parameters, band=self.band, target=self.target, ) # Load the asn file in, so we have access to everything we need later asn_file = ModelContainer(asn_file) # Run the tweakreg step config = TweakRegStep.get_config_from_reference(asn_file) tweakreg = TweakRegStep.from_config_section(config) tweakreg.output_dir = self.out_dir tweakreg.save_results = False tweakreg.kernel_fwhm = fwhm_pix # * 2 try: tweakreg_params = self.jwst_parameters["tweakreg"] except KeyError: tweakreg_params = {} for tweakreg_key in tweakreg_params: value = parse_parameter_dict( parameters=tweakreg_params, key=tweakreg_key, band=self.band, target=self.target, ) if value == "VAL_NOT_FOUND": continue recursive_setattr(tweakreg, tweakreg_key, value) # Keep track of exposure numbers and group IDs in case we change them meta_params = {} for model in asn_file._models: model_name = model.meta.filename if hasattr(model.meta.observation, "exposure_number"): exp_no = copy.deepcopy(model.meta.observation.exposure_number) else: exp_no = "" if hasattr(model.meta, "group_id"): group_id = copy.deepcopy(model.meta.group_id) else: group_id = "" meta_params[model_name] = [exp_no, group_id, ] # Group up the dithers if short_long in self.tweakreg_group_dithers: for model in asn_file._models: model.meta.observation.exposure_number = "1" model.meta.group_id = "" # Or degroup the dithers elif short_long in self.tweakreg_degroup_dithers: for i, model in enumerate(asn_file._models): model.meta.observation.exposure_number = str(i) model.meta.group_id = "" # If needed, degroup the NIRCam modules. Do this by adding a large # number to the exposure number if ( band_type == "nircam" and self.tweakreg_degroup_nircam_modules ): for i, model in enumerate(asn_file._models): module = model.meta.instrument.module.strip().lower() exp_no = int(model.meta.observation.exposure_number) if module == "a": exp_add = 99 elif module == "b": exp_add = 100 else: raise ValueError("Expecting module to either be A or B") model.meta.observation.exposure_number = str(exp_no + exp_add) model.meta.group_id = "" # Degroup the 1/2/3/4 NIRCam shorts, if requested if ( band_type == "nircam" and self.tweakreg_degroup_nircam_short_chips ): for i, model in enumerate(asn_file._models): detector = model.meta.instrument.detector.strip().lower() exp_no = int(model.meta.observation.exposure_number) # Include information from the particular chip if we're in short # mode (i.e. there's a 1-4 in the detector name), and keep # track of this to modify the group ID exp_add = 0 if "1" in detector: exp_add += 49 elif "2" in detector: exp_add += 50 elif "3" in detector: exp_add += 51 elif "4" in detector: exp_add += 52 model.meta.observation.exposure_number = str(exp_no + exp_add) model.meta.group_id = "" asn_file = tweakreg.run(asn_file) # If the asn file is a ModelLibrary, we need to force the name back in use_model_library = False if isinstance(asn_file, ModelLibrary): use_model_library = True # If there's no final name here, add it now if "name" not in asn_file._asn["products"][0]: name = f"{self.target.lower()}_{self.band_type}_lv3_{band_short.lower()}{self.bgr_ext}" asn_file._asn["products"][0]["name"] = copy.deepcopy(name) del tweakreg gc.collect() # Make sure we skip tweakreg since we've already done it im3.tweakreg.skip = True if use_model_library: models = asn_file._loaded_models else: models = asn_file._models # Reset if we're degrouping NIRCam modules if ( band_type == "nircam" and self.tweakreg_degroup_nircam_modules ): if use_model_library: for i in models: model_name = models[i].meta.filename models[i].meta.observation.exposure_number = meta_params[model_name][0] models[i].meta.group_id = meta_params[model_name][1] else: for i, model in enumerate(models): model_name = model.meta.filename model.meta.observation.exposure_number = meta_params[model_name][0] model.meta.group_id = meta_params[model_name][1] # Remove the chip info if we're degrouping the NIRCam short chips if ( band_type == "nircam" and self.tweakreg_degroup_nircam_short_chips ): if use_model_library: for i in models: model_name = models[i].meta.filename models[i].meta.observation.exposure_number = meta_params[model_name][0] models[i].meta.group_id = meta_params[model_name][1] else: for i, model in enumerate(models): model_name = model.meta.filename model.meta.observation.exposure_number = meta_params[model_name][0] model.meta.group_id = meta_params[model_name][1] # Set meta parameters back to original values for group/degrouping of dithers if ( short_long in self.tweakreg_group_dithers or short_long in self.tweakreg_degroup_dithers ): if use_model_library: for i in models: model_name = models[i].meta.filename models[i].meta.observation.exposure_number = meta_params[model_name][0] models[i].meta.group_id = meta_params[model_name][1] else: for i, model in enumerate(models): model_name = model.meta.filename model.meta.observation.exposure_number = meta_params[model_name][0] model.meta.group_id = meta_params[model_name][1] # Run the skymatch step with custom hacks if required config = SkyMatchStep.get_config_from_reference(asn_file) skymatch = SkyMatchStep.from_config_section(config) skymatch.output_dir = self.out_dir skymatch.save_results = False try: skymatch_params = self.jwst_parameters["skymatch"] except KeyError: skymatch_params = {} for skymatch_key in skymatch_params: value = parse_parameter_dict( parameters=skymatch_params, key=skymatch_key, band=self.band, target=self.target, ) if value == "VAL_NOT_FOUND": continue recursive_setattr(skymatch, skymatch_key, value) # Group or degroup for skymatching if short_long in self.skymatch_group_dithers: if use_model_library: for i in models: models[i].meta.observation.exposure_number = "1" models[i].meta.group_id = "" else: for model in models: model.meta.observation.exposure_number = "1" model.meta.group_id = "" elif short_long in self.skymatch_degroup_dithers: if use_model_library: for i in models: models[i].meta.observation.exposure_number = str(i) models[i].meta.group_id = "" else: for i, model in enumerate(models): model.meta.observation.exposure_number = str(i) model.meta.group_id = "" asn_file = skymatch.run(asn_file) del skymatch gc.collect() use_model_library = False if isinstance(asn_file, ModelLibrary): use_model_library = True # If there's no final name here, add it now if "name" not in asn_file._asn["products"][0]: name = f"{self.target.lower()}_{self.band_type}_lv3_{band_short.lower()}{self.bgr_ext}" asn_file._asn["products"][0]["name"] = copy.deepcopy(name) if use_model_library: models = asn_file._loaded_models else: models = asn_file._models # Set meta parameters back to original values to avoid potential weirdness later if ( short_long in self.skymatch_group_dithers or short_long in self.skymatch_degroup_dithers ): if use_model_library: for i in models: model_name = models[i].meta.filename models[i].meta.observation.exposure_number = meta_params[model_name][0] models[i].meta.group_id = meta_params[model_name][1] else: for i, model in enumerate(models): model_name = model.meta.filename model.meta.observation.exposure_number = meta_params[model_name][0] model.meta.group_id = meta_params[model_name][1] im3.skymatch.skip = True # Run the rest of the level 3 pipeline if use_model_library: # Re-instantiate the ModelLibrary, to wipe out any weirdness we might have performed # along the way asn_file = ModelContainer([models[m] for m in models]) asn_file = ModelLibrary(asn_file, on_disk=False) if "name" not in asn_file._asn["products"][0]: name = f"{self.target.lower()}_{self.band_type}_lv3_{band_short.lower()}{self.bgr_ext}" asn_file._asn["products"][0]["name"] = copy.deepcopy(name) im3.run(asn_file) del im3 del asn_file gc.collect() # Drizzle individual frames to the common mosaic WCS if self.do_drizzle: i2d_files = glob.glob(os.path.join(self.out_dir, "*_i2d.fits")) crf_files = sorted(glob.glob(os.path.join(self.out_dir, "*_crf.fits"))) if i2d_files and crf_files: ref_wcs_file = os.path.join(self.out_dir, "resample_refwcs.asdf") with datamodels.open(i2d_files[0]) as i2d_model: asdf.AsdfFile({ "wcs": i2d_model.meta.wcs, "array_shape": i2d_model.data.shape, }).write_to(ref_wcs_file) resample_single = ResampleStep() resample_single.single = True resample_single.output_use_model = True resample_single.save_results = True resample_single.output_dir = self.out_dir resample_single.output_wcs = ref_wcs_file try: resample_params = self.jwst_parameters["resample"] except KeyError: resample_params = {} wcs_keys = { "output_wcs", "crval", "crpix", "rotation", "pixel_scale", "pixel_scale_ratio", "output_shape", } for resample_key in resample_params: if resample_key in wcs_keys: continue value = parse_parameter_dict( parameters=resample_params, key=resample_key, band=self.band, target=self.target, ) if value == "VAL_NOT_FOUND": continue recursive_setattr(resample_single, resample_key, value) resample_single.run(crf_files) del resample_single gc.collect() # Notes: # - ResampleImage.resample_many_to_many() adds _outlier_resamplestep.fits # to the file names, but it's not an outlier image, so let's rename it. # - We'd have to edit ResampleStep for a proper fix. # - Can't have suffix ending in _i2d.fits, messes with anchoring pattern matching. for f in glob.glob(os.path.join(self.out_dir, "*_outlier_resamplestep.fits")): os.rename(f, f.replace("_outlier_resamplestep.fits", "_i2d_single.fits")) else: log.warning("do_drizzle is set but no i2d/crf files found") if self.do_blot: i2d_files = glob.glob(os.path.join(self.out_dir, "*_i2d.fits")) crf_files = sorted(glob.glob(os.path.join(self.out_dir, "*_crf.fits"))) if i2d_files and crf_files: self._blot_to_detector_frame(crf_files=crf_files) else: log.warning("do_blot is set but no i2d/crf files found") return True
def _blot_to_detector_frame( self, crf_files, ): """Blot the final i2d mosaic to each exposure's detector frame. The ``*_i2d.fits`` mosaic is blotted onto the detector frame of every CRF exposure, producing one ``*_i2d_blot.fits`` file per exposure alongside the other lv3 products. Args: crf_files: Sorted list of CRF file paths (one per exposure). Each CRF provides the detector-frame GWCS and shape that the mosaic is blotted onto. """ i2d_files = glob.glob(os.path.join(self.out_dir, "*_i2d.fits")) if not i2d_files: log.warning("_blot_to_detector_frame: no *_i2d.fits mosaic found") return if len(i2d_files) > 1: log.warning( "_blot_to_detector_frame: found %d i2d files, using %s", len(i2d_files), os.path.basename(i2d_files[0]), ) log.info( "Blotting i2d mosaic to %d exposure detector frames", len(crf_files), ) with datamodels.open(i2d_files[0]) as i2d_model: mosaic_data = i2d_model.data.astype(np.float32) mosaic_wcs = i2d_model.meta.wcs for crf_file in crf_files: with datamodels.open(crf_file) as crf_model: blot_wcs = crf_model.meta.wcs blot_shape = crf_model.data.shape pixflux_area = crf_model.meta.photometry.pixelarea_steradians blot_wcs.array_shape = blot_shape pixel_area = compute_image_pixel_area(blot_wcs) pix_ratio = np.sqrt(pixflux_area / pixel_area) blotted = gwcs_blot( median_data=mosaic_data, median_wcs=mosaic_wcs, blot_shape=blot_shape, blot_wcs=blot_wcs, pix_ratio=pix_ratio, fillval=self.blot_fillval, ) out_name = crf_file.replace("_crf.fits", "_i2d_blot.fits") blot_model = datamodels.ImageModel(data=blotted) blot_model.update(crf_model) blot_model.meta.wcs = copy.deepcopy(blot_wcs) save_file(blot_model, out_name=out_name, dr_version=self.dr_version) blot_model.close() log.info( " mosaic -> %s", os.path.basename(out_name), ) gc.collect()
[docs] def propagate_metadata(self, files): """Propagate metadata through to the output files Args: files: List of files to loop over """ for out_name in files: with datamodels.open(out_name) as im: save_file(im, out_name=out_name, dr_version=self.dr_version) del im return True