Source code for pjpipe.apply_wcs_adjust.apply_wcs_adjust_step

import copy
import gc
import glob
import logging
import multiprocessing as mp
import os
import re
import shutil
import warnings
from fnmatch import fnmatch
from functools import partial

import numpy as np
from jwst.assign_wcs.util import update_fits_wcsinfo
from stdatamodels.jwst import datamodels
from tqdm import tqdm
from tweakwcs.correctors import JWSTWCSCorrector

from ..utils import band_exts

# Module-level logger, named after this module
log = logging.getLogger(name=__name__)


class ApplyWCSAdjustStep:
    def __init__(
        self,
        wcs_adjust,
        in_dir,
        out_dir,
        step_ext,
        procs,
        overwrite=False,
    ):
        """Apply WCS adjustments to images

        Args:
            wcs_adjust: Dictionary for WCS adjustments. The per-visit
                adjustments are read from its "wcs_adjust" entry
            in_dir: Input directory
            out_dir: Output directory
            step_ext: .fits extension for the files going into the step
            procs: Number of processes to run in parallel
            overwrite: Whether to overwrite or not. Defaults to False
        """
        self.wcs_adjust = wcs_adjust
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.step_ext = step_ext
        self.procs = procs
        self.overwrite = overwrite

    def do_step(self):
        """Run applying the WCS adjustments

        Files matching ``*_{step_ext}.fits`` in ``in_dir`` are processed and
        written, under the same names, to ``out_dir``. An empty marker file
        in ``out_dir`` records completion, so re-runs are skipped unless
        ``overwrite`` is set.

        Returns:
            True if the step completed (or had already been run), False if
            any file reported a failure
        """
        # rmtree raises FileNotFoundError on a missing path, so only clear
        # the output directory if it is actually there
        if self.overwrite and os.path.exists(self.out_dir):
            shutil.rmtree(self.out_dir)
        os.makedirs(self.out_dir, exist_ok=True)

        # Check if we've already run the step
        step_complete_file = os.path.join(
            self.out_dir,
            "apply_wcs_adjust_step_complete.txt",
        )
        if os.path.exists(step_complete_file):
            log.info("Step already run")
            return True

        files = sorted(
            glob.glob(
                os.path.join(
                    self.in_dir,
                    f"*_{self.step_ext}.fits",
                )
            )
        )
        if not files:
            log.warning(f"No *_{self.step_ext}.fits files found in {self.in_dir}")

        # Ensure we're not wasting processes, but always request at least
        # one: a pool of size 0 raises ValueError
        procs = max(1, int(min(self.procs, len(files))))

        successes = self.run_step(
            files,
            procs=procs,
        )

        if not all(successes):
            log.warning("Failures detected in applying WCS adjustments")
            return False

        # Create the (empty) completion marker
        with open(step_complete_file, "w+"):
            pass

        return True

    def run_step(
        self,
        files,
        procs=1,
    ):
        """Wrap parallelism around applying WCS adjusts

        Args:
            files: List of files to apply WCS adjustments to
            procs: Number of parallel processes to run. Defaults to 1

        Returns:
            List of per-file return values from ``parallel_wcs_adjust``, in
            completion order (not input order)
        """
        log.info("Applying WCS corrections")

        # Nothing to do; avoid spinning up a pool at all
        if not files:
            return []

        # NOTE(review): parallel_wcs_adjust only ever returns True; an
        # exception in a worker is re-raised here while iterating rather
        # than being counted as a failure
        with mp.get_context("fork").Pool(procs) as pool:
            successes = list(
                tqdm(
                    pool.imap_unordered(self.parallel_wcs_adjust, files),
                    ascii=True,
                    desc="Applying WCS corrections",
                    total=len(files),
                )
            )
            pool.close()
            pool.join()

        gc.collect()

        return successes

    @staticmethod
    def _find_wcs_adjust(adjusts, file_short, nircam_ext):
        """Look up the WCS adjustment (matrix, shift) that applies to a file

        The ungrouped visit (first three ``_``-separated fields of the file
        name) is tried before the grouped visit (first field only). Within
        each, the first matching key of ``adjusts`` in iteration order wins.

        A key matching ``*{nircam_ext}`` is treated as a per-module NIRCam
        key: its last ``_``-separated field is dropped before comparing it
        to the visit. If the file is also NIRCam, the ``nrc[ab]`` module
        letter of that last field must equal the first one in the file name.
        NOTE(review): the module check is skipped for non-NIRCam files, so a
        per-module NIRCam key can match e.g. a MIRI file with the same
        visit. This mirrors the original behaviour; confirm it is intended.

        Args:
            adjusts: Mapping of visit key to a dict optionally holding
                "matrix" and "shift"
            file_short: File name, without directory
            nircam_ext: Name fragment identifying NIRCam files/keys
                (``band_exts["nircam"]`` in practice)

        Returns:
            ``(matrix, shift)``, defaulting to the identity matrix and a zero
            shift for missing entries, or None if no key matches
        """
        nircam_pattern = f"*{nircam_ext}"

        def nircam_module(name):
            # First NIRCam module letter in name, or None if there isn't one
            match = re.search(r"nrc([ab])", name)
            return match.group(1) if match else None

        is_nircam = fnmatch(file_short, nircam_pattern)
        file_module = nircam_module(file_short)

        file_split = file_short.split("_")
        visit_grouped = file_split[0]
        visit_ungrouped = "_".join(file_split[:3])

        # Prefer ungrouped matches over grouped ones
        for visit in (visit_ungrouped, visit_grouped):
            for adjust, adjust_vals in adjusts.items():
                adjust_split = adjust.split("_")
                adjust_is_nircam_degrouped = fnmatch(adjust, nircam_pattern)

                # Strip the trailing module field so the key looks like a
                # visit name
                if adjust_is_nircam_degrouped:
                    adjust_visit = "_".join(adjust_split[:-1])
                else:
                    adjust_visit = adjust

                if adjust_visit != visit:
                    continue

                # If they're both NIRCam, the module has to match too
                if adjust_is_nircam_degrouped and is_nircam:
                    if nircam_module(adjust_split[-1]) != file_module:
                        continue

                matrix = adjust_vals.get("matrix", [[1, 0], [0, 1]])
                shift = adjust_vals.get("shift", [0, 0])
                return matrix, shift

        return None

    def parallel_wcs_adjust(
        self,
        file,
    ):
        """Apply the WCS adjustment for a single file and write it out

        The file is written to ``out_dir`` under its original name whether
        or not a matching adjustment was found.

        Args:
            file: File to apply WCS corrections to

        Returns:
            True once the output file has been written
        """
        file_short = os.path.split(file)[-1]
        output_file = os.path.join(
            self.out_dir,
            file_short,
        )

        # (matrix, shift) for this file, or None if no adjustment matches
        correction = self._find_wcs_adjust(
            self.wcs_adjust["wcs_adjust"],
            file_short,
            band_exts["nircam"],
        )

        # Set up the WCSCorrector per tweakreg
        with datamodels.open(file) as input_im:
            model_name = os.path.splitext(input_im.meta.filename)[0].strip("_- ")
            refang = input_im.meta.wcsinfo.instance

            im = JWSTWCSCorrector(
                wcs=input_im.meta.wcs,
                wcsinfo={
                    "roll_ref": refang["roll_ref"],
                    "v2_ref": refang["v2_ref"],
                    "v3_ref": refang["v3_ref"],
                },
                meta={"image_model": input_im, "name": model_name},
            )

            image_model = im.meta["image_model"]

            if correction is None:
                # Written out unchanged; no copy of the model is needed
                # just to save it
                log.info(f"No shifts found for {file_short}. Will write out without shifting")
            else:
                matrix, shift = correction
                im.set_correction(matrix=matrix, shift=shift)
                image_model.meta.wcs = im.wcs

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    try:
                        update_fits_wcsinfo(
                            image_model,
                        )
                    except (ValueError, RuntimeError) as e:
                        # str(e) rather than e.args[0]: args may be empty
                        log.warning(
                            "Failed to update 'meta.wcsinfo' with FITS SIP "
                            f"approximation. Reported error is:\n'{e}'"
                        )

            image_model.save(output_file)

        # Drop references to the (now closed) model before collecting
        del image_model, im, input_im
        gc.collect()

        return True