Source code for pjpipe.get_wcs_adjust.get_wcs_adjust_step

import copy
import glob
import logging
import os
import shutil

import numpy as np
from astropy.table import Table, QTable
from jwst.datamodels import ModelContainer
from jwst.tweakreg import TweakRegStep
from stdatamodels.jwst import datamodels

from ..utils import get_band_type, fwhms_pix, parse_parameter_dict, recursive_setattr

# Conversion factor from radians to arcseconds (degrees per radian x 3600)
RAD_TO_ARCSEC = np.rad2deg(1) * 3600

# Shared "stpipe" logger; the NullHandler avoids "no handler" warnings when the
# caller has not configured logging
log = logging.getLogger("stpipe")
log.addHandler(logging.NullHandler())


def write_visit_transforms(
        visit_transforms,
        out_file,
):
    """Write out table of WCS transforms

    The output is a small TOML file with a single ``[wcs_adjust]`` table. For
    each visit, a ``<visit>.shift`` entry (two values) and a ``<visit>.matrix``
    entry (a 2x2 nested list) are written. Where several measurements were
    stacked for a visit, their NaN-ignoring mean is written instead. All values
    are rounded to three decimal places.

    Args:
        visit_transforms: Dictionary of transforms per visit. Each value is a
            dict with a ``"shift"`` entry (a single 1D shift, or shifts stacked
            as rows) and a ``"matrix"`` entry (a single 2D matrix, or matrices
            stacked along the last axis, as ``np.dstack`` produces)
        out_file: Name for the output .toml file

    Returns:
        True once the file has been written
    """
    log.info("Writing transforms")

    def _fmt(values):
        # Round to 3 d.p. and convert to plain floats for a compact, readable file
        return [float(f"{v:.3f}") for v in values]

    with open(out_file, "w", encoding="utf-8") as f:
        f.write("[wcs_adjust]\n\n")

        # Skip where we don't have anything
        if not visit_transforms:
            log.info("No WCS adjusts found. Skipping")
            return True

        for visit, transform in visit_transforms.items():
            shift = np.asarray(transform["shift"])
            matrix = np.asarray(transform["matrix"])

            # A single shift is 1D; stacked shifts are rows, so average over rows
            if shift.ndim != 1:
                shift = np.nanmean(shift, axis=0)

            # A single matrix is 2D; stacked matrices live along the last axis
            if matrix.ndim != 2:
                matrix = np.nanmean(matrix, axis=-1)

            # Format these as nice strings and write out
            f.write(f"{visit}.shift = {_fmt(shift)}\n")
            f.write(f"{visit}.matrix = [\n\t{_fmt(matrix[0])},\n\t{_fmt(matrix[1])}\n]\n")

        f.write("\n")

    return True


class GetWCSAdjustStep:
    """Measure per-visit WCS corrections with tweakreg and write them to a TOML file"""

    def __init__(
            self,
            directory,
            progress_dict,
            target,
            alignment_dir,
            bands=None,
            alignment_catalogs=None,
            group_dithers=None,
            tweakreg_parameters=None,
            overwrite=False,
    ):
        """Gets a table of WCS corrections to apply to visit groups

        Experience has shown that the relative JWST guide star uncertainty is
        very small, but there are significant absolute corrections between guide
        stars. Thus, a correction measured in one band can be applied to all
        bands within the same visit. For example, F770W/F1000W can be used to
        correct F2100W, where tweakreg doesn't work so well.

        Here, we take some template bands and loop over them with tweakreg,
        writing out a table of shifts/matrices to apply to other bands. For
        multiple dithers etc., an average correction is taken.

        Args:
            directory: Directory of target
            progress_dict: The progress dictionary the pipeline builds up. This
                is used to figure out what subdirectories we should be looking in
                (via each band's ``"dir"`` and ``"success"`` entries)
            target: Target to consider
            alignment_dir: Directory for alignment catalogs
            bands: List of target bands to pull corrections out for
            alignment_catalogs: Dictionary mapping targets to alignment catalogs
            group_dithers: Which band type (e.g. nircam) to group up dithers for
                and find a single correction. Defaults to None, which won't
                group up anything
            tweakreg_parameters: Dictionary of parameters to pass to tweakreg.
                Defaults to None, which will use observatory defaults
            overwrite: Whether to overwrite or not. Defaults to False

        Raises:
            ValueError: If ``bands`` is not given
        """
        if bands is None:
            raise ValueError("Need some bands to get WCS adjustments")
        if group_dithers is None:
            group_dithers = []
        if tweakreg_parameters is None:
            tweakreg_parameters = {}
        if alignment_catalogs is None:
            alignment_catalogs = {}

        self.directory = directory
        self.progress_dict = progress_dict
        self.target = target
        self.alignment_dir = alignment_dir
        self.bands = bands
        self.alignment_catalogs = alignment_catalogs
        self.group_dithers = group_dithers
        self.tweakreg_parameters = tweakreg_parameters
        self.overwrite = overwrite

    def do_step(self):
        """Run the WCS adjust step

        Returns:
            True if the step completed (or had already been run), False if
            writing the transforms reported a failure
        """
        step_complete_file = os.path.join(
            self.directory,
            "get_wcs_adjust_step_complete.txt",
        )
        out_file = os.path.join(self.directory, f"{self.target}_wcs_adjust.toml")

        if self.overwrite:
            for existing_file in (out_file, step_complete_file):
                if os.path.exists(existing_file):
                    os.remove(existing_file)

        if os.path.exists(step_complete_file):
            log.info("Step already run")
            return True

        # Get transforms
        visit_transforms = self.get_visit_transforms()

        # Write transforms
        success = write_visit_transforms(
            visit_transforms,
            out_file,
        )

        if not success:
            log.warning("Failures detected in getting WCS adjustments")
            return False

        # Touch an empty marker file so later runs can skip this step
        with open(step_complete_file, "w"):
            pass

        return True

    def get_visit_transforms(self):
        """Get transforms per-visit, running tweakreg and pulling out corrections

        Tweakreg outputs are written to a temporary ``get_wcs_adjust``
        subdirectory of ``self.directory``. It is cleared before use, so
        leftovers from an interrupted run cannot leak in, and removed afterwards.

        Returns:
            Dictionary keyed by visit name (sorted), each with ``"shift"`` and
            ``"matrix"`` entries. Where several outputs map to the same visit,
            shifts are stacked as rows and matrices along the last axis
        """
        in_ext = "cal"
        out_ext = "wcs_adjust"

        out_dir = os.path.join(self.directory, "get_wcs_adjust")

        # Start from a clean temp directory: stale outputs from a previous,
        # interrupted run would otherwise be picked up when collecting results
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)

        log.info("Getting transforms")

        for band_full in self.bands:
            self._run_band_tweakreg(
                band_full=band_full,
                in_ext=in_ext,
                out_ext=out_ext,
                out_dir=out_dir,
            )

        visit_transforms = self._collect_visit_transforms(
            out_dir=out_dir,
            out_ext=out_ext,
        )

        # Remove the temp directory
        shutil.rmtree(out_dir)

        # Sort the dictionary so the file is more human-readable
        return dict(sorted(visit_transforms.items()))

    def _run_band_tweakreg(self, band_full, in_ext, out_ext, out_dir):
        """Run tweakreg on one band's ``*_<in_ext>.fits`` files, saving outputs to ``out_dir``"""
        band = band_full.replace("_bgr", "") if "bgr" in band_full else band_full

        # Some various failure states
        if band_full not in self.progress_dict:
            log.warning(f"No data found for {band_full}. Skipping")
            return
        if "dir" not in self.progress_dict[band_full]:
            log.warning(f"No files found for {band_full}. Skipping")
            return
        if not self.progress_dict[band_full]["success"]:
            log.warning(f"Previous failures found for {band_full}. Skipping")
            return

        band_dir = self.progress_dict[band_full]["dir"]
        if not os.path.exists(band_dir):
            log.warning(f"Directory {band_dir} does not exist")
            return

        in_files = sorted(
            glob.glob(
                os.path.join(
                    band_dir,
                    f"*_{in_ext}.fits",
                )
            )
        )
        # An empty association would be handed straight to tweakreg otherwise
        if not in_files:
            log.warning(f"No *_{in_ext}.fits files found in {band_dir}. Skipping")
            return

        band_type = get_band_type(band)
        fwhm_pix = fwhms_pix[band]

        input_models = [datamodels.open(in_file) for in_file in in_files]
        try:
            asn_file = ModelContainer(input_models)

            # Group up the dithers
            if band_type in self.group_dithers:
                for model in asn_file._models:
                    model.meta.observation.exposure_number = "1"
                    model.meta.group_id = ""

            # If we only have one group, this won't do anything so just skip
            if (
                    len(asn_file.models_grouped) == 1
                    and self.target not in self.alignment_catalogs
            ):
                log.info("Only one group and no absolute alignment happening. Skipping")
                return

            tweakreg_config = TweakRegStep.get_config_from_reference(asn_file)
            tweakreg = TweakRegStep.from_config_section(tweakreg_config)
            tweakreg.output_dir = out_dir
            tweakreg.save_results = True
            tweakreg.suffix = out_ext
            tweakreg.kernel_fwhm = fwhm_pix * 2

            if self.target in self.alignment_catalogs:
                tweakreg.abs_refcat = self._get_abs_ref_catalog()

            for tweakreg_key in self.tweakreg_parameters:
                value = parse_parameter_dict(
                    self.tweakreg_parameters,
                    tweakreg_key,
                    band,
                    self.target,
                )
                if value == "VAL_NOT_FOUND":
                    continue

                recursive_setattr(tweakreg, tweakreg_key, value)

            tweakreg.run(asn_file)
        finally:
            # Release the file handles opened above, whether or not tweakreg ran
            for model in input_models:
                model.close()

    def _get_abs_ref_catalog(self):
        """Return the path to the absolute reference catalog, building it if needed

        The catalog is written once to ``<directory>/<target>_ref_catalog.fits``.
        The ``ra``/``dec`` columns of the target's alignment catalog are copied
        into ``RA``/``DEC`` columns, a format that tweakreg is happy with.
        """
        abs_ref_catalog = os.path.join(
            self.directory,
            f"{self.target}_ref_catalog.fits",
        )
        if not os.path.exists(abs_ref_catalog):
            in_catalog = os.path.join(
                self.alignment_dir,
                self.alignment_catalogs[self.target],
            )
            align_table = QTable.read(in_catalog, format="fits")

            abs_tab = Table()
            abs_tab["RA"] = align_table["ra"]
            abs_tab["DEC"] = align_table["dec"]
            abs_tab.write(abs_ref_catalog, overwrite=True)

        return abs_ref_catalog

    def _collect_visit_transforms(self, out_dir, out_ext):
        """Pull per-visit shifts/matrices out of the tweakreg outputs in ``out_dir``"""
        visit_transforms = {}

        output_files = glob.glob(
            os.path.join(
                out_dir,
                f"*_{out_ext}.fits",
            )
        )

        for output_file in output_files:
            # Get matrix and (x, y) shifts from the output file, if they exist
            with datamodels.open(output_file) as aligned_model:
                try:
                    transform = aligned_model.meta.wcs.forward_transform["tp_affine"]
                except IndexError:
                    # No tweakreg correction in this WCS, so nothing to record
                    log.debug(f"No tp_affine transform found in {output_file}. Skipping")
                    continue

                matrix = transform.matrix.value
                xy_shift = RAD_TO_ARCSEC * transform.translation.value

                # Pull out a visit name. This will be different if the band is
                # having dithers grouped or not
                out_split = os.path.split(output_file)[-1]
                band_type = aligned_model.meta.instrument.name.strip().lower()
                if band_type in self.group_dithers:
                    visit = out_split.split("_")[0]
                else:
                    visit = "_".join(out_split.split("_")[:3])

                if visit in visit_transforms:
                    visit_transforms[visit]["shift"] = np.vstack(
                        (visit_transforms[visit]["shift"], xy_shift)
                    )
                    visit_transforms[visit]["matrix"] = np.dstack(
                        (visit_transforms[visit]["matrix"], matrix)
                    )
                else:
                    visit_transforms[visit] = {
                        "shift": copy.deepcopy(xy_shift),
                        "matrix": copy.deepcopy(matrix),
                    }

        return visit_transforms