Source code for pjpipe.lv1.lv1_step

import copy
import gc
import glob
import logging
import multiprocessing as mp
import os
import shutil
from functools import partial

import numpy as np
from jwst.pipeline import calwebb_detector1
from stdatamodels.jwst import datamodels

from ..utils import attribute_setter, save_file

log = logging.getLogger("stpipe")
log.addHandler(logging.NullHandler())


[docs] class Lv1Step: def __init__( self, target, band, in_dir, out_dir, dr_version, step_ext, procs, jwst_parameters=None, overwrite=False, ): """Wrapper around the level 1 JWST pipeline Args: target: Target to consider band: Band to consider in_dir: Input directory out_dir: Output directory dr_version: Data processing version step_ext: .fits extension for the files going into the step procs: Number of processes to run in parallel. jwst_parameters: Parameter dictionary to pass to the level 1 pipeline. Defaults to None, which will run the observatory defaults overwrite: Whether to overwrite or not. Defaults to False """ if jwst_parameters is None: jwst_parameters = {} self.target = target self.band = band self.in_dir = in_dir self.out_dir = out_dir self.dr_version = dr_version self.step_ext = step_ext self.procs = procs self.jwst_parameters = jwst_parameters self.overwrite = overwrite
[docs] def do_step(self): """Run the level 1 pipeline""" if self.overwrite: shutil.rmtree(self.out_dir) if not os.path.exists(self.out_dir): os.makedirs(self.out_dir) # Check if we've already run the step step_complete_file = os.path.join( self.out_dir, "lv1_step_complete.txt", ) if os.path.exists(step_complete_file): log.info("Step already run") return True # We need to operate this in the input directory cwd = os.getcwd() os.chdir(self.in_dir) # Build file list in_files = glob.glob(f"*_{self.step_ext}.fits") if len(in_files) == 0: log.warning(f"No {self.step_ext} files found") os.chdir(cwd) return False in_files.sort() # For speed, we want to parallelise these up by dither since we use the # persistence file dithers = [] for file in in_files: file_split = os.path.split(file)[-1].split("_") dithers.append("_".join(file_split[:2]) + "_*_" + file_split[-2]) dithers = np.unique(dithers) dithers.sort() # Ensure we're not wasting processes procs = np.nanmin([self.procs, len(dithers)]) successes = self.run_step( dithers, procs=procs, ) # If not everything has succeeded, then return a warning if not np.all(successes): log.warning("Failures detected in level 1 pipeline") os.chdir(cwd) return False os.chdir(cwd) with open(step_complete_file, "w+") as f: f.close() return True
[docs] def run_step( self, dithers, procs=1, ): """Wrap parallelism around the level 1 pipeline Args: dithers: List of dithers to loop over procs: Number of processes to run. Defaults to 1 """ log.info("Running level 1 pipeline") # Ensure we pre-cache references, to avoid errors in multiprocessing. Loop over # all just to be safe for dither in dithers: uncal_files = glob.glob(f"{dither}*_{self.step_ext}.fits") for uncal_file in uncal_files: config = calwebb_detector1.Detector1Pipeline.get_config_from_reference( uncal_file ) detector1 = calwebb_detector1.Detector1Pipeline.from_config_section( config ) detector1._precache_references(uncal_file) with mp.get_context("fork").Pool(procs) as pool: successes = [] for success in pool.imap_unordered( partial( self.parallel_lv1, ), dithers, ): successes.append(success) pool.close() pool.join() gc.collect() return successes
[docs] def parallel_lv1( self, dither, ): """Parallelise lv1 reprocessing Args: dither: Name for dither group. This is used because we inherit persistence from previous integration in the set """ uncal_files = glob.glob(f"{dither}*_{self.step_ext}.fits") uncal_files.sort() for uncal_file in uncal_files: config = calwebb_detector1.Detector1Pipeline.get_config_from_reference( uncal_file ) detector1 = calwebb_detector1.Detector1Pipeline.from_config_section(config) # Pull out the trapsfilled file from preceding exposure persist_file = "" uncal_file_split = uncal_file.split("_") exposure_str = uncal_file_split[2] prev_exposure_int = int(exposure_str) - 1 if prev_exposure_int > 0: prev_exposure_str = f"{prev_exposure_int:05}" persist_file = copy.deepcopy(uncal_file_split) persist_file[2] = prev_exposure_str persist_file[-1] = "trapsfilled.fits" persist_file = os.path.join(self.out_dir, "_".join(persist_file)) # Specify the name of the trapsfilled file detector1.persistence.input_trapsfilled = persist_file # Set other parameters detector1.output_dir = self.out_dir detector1 = attribute_setter( detector1, parameters=self.jwst_parameters, band=self.band, target=self.target, ) # Run the level 1 pipeline detector1.run(uncal_file) del detector1 # Since running these steps seems to destroy the history parameter, # add this back in out_name = os.path.join(self.out_dir, uncal_file.replace(f"{self.step_ext}.fits", "rate.fits"), ) with datamodels.open(out_name) as im: save_file(im, out_name=out_name, dr_version=self.dr_version) del im gc.collect() return True