# Source code for pjpipe.regress_against_previous.regress_against_previous_step

import copy
import glob
import logging
import os
import warnings

import cmocean
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pypdf import PdfWriter
from reproject import reproject_interp, reproject_adaptive, reproject_exact

# Names accepted for the ``reproject_func`` option. Kept as a list because it
# is interpolated verbatim into the ValueError messages raised below.
ALLOWED_REPROJECT_FUNCS = ["interp", "adaptive", "exact"]

# Select the "agg" backend; figures in this module are only saved to disk
matplotlib.use("agg")

# Global plot styling: STIX fonts for both maths and regular text
matplotlib.rcParams.update(
    {
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
        "font.size": 14,
    }
)

log = logging.getLogger(__name__)


def get_diff_image(
    filename,
    v_curr,
    v_prev,
    curr_file_ext,
    reproject_func="interp",
    percentiles=None,
    file_exts=None,
):
    """Build a current-minus-previous difference image

    The previous-version file is located by substituting ``v_prev`` for
    ``v_curr`` and each candidate extension for ``curr_file_ext`` in
    ``filename``; the first candidate that exists on disk is used. Both SCI
    images are reprojected onto the current image's SCI header and then
    subtracted.

    Args:
        filename: Name of the current-version file
        v_curr: Current version string, as it appears in ``filename``
        v_prev: Previous version string to substitute for ``v_curr``
        curr_file_ext: File extension of the current file
        reproject_func: Which reproject function to use. Defaults to
            'interp', but can also be 'exact' or 'adaptive'
        percentiles: Percentiles of the diff image used to derive the
            symmetric colour limit. Defaults to None, which will be
            the [1, 99]th percentiles
        file_exts: List of file extensions to search for the previous
            file, in priority order. Defaults to None, which will go
            anchor->align->pipeline.

    Returns:
        Tuple ``(diff, v)``: ``diff`` is the current-minus-previous array and
        ``v`` is the largest absolute value among the requested percentiles
        of ``diff``. ``(None, None)`` if no previous-version file was found.

    Raises:
        ValueError: If ``reproject_func`` is not one of the allowed names
    """
    percentiles = [1, 99] if percentiles is None else percentiles
    file_exts = (
        ["i2d_anchor.fits", "i2d_align.fits", "i2d.fits"]
        if file_exts is None
        else file_exts
    )

    with fits.open(filename) as curr_hdul:
        # Lazily walk the candidate names in priority order, stopping at the
        # first one that exists
        candidates = (
            filename.replace(v_curr, v_prev).replace(curr_file_ext, ext)
            for ext in file_exts
        )
        prev_filename = next(
            (candidate for candidate in candidates if os.path.exists(candidate)),
            None,
        )
        if prev_filename is None:
            return None, None

        # Exact zeros are replaced by NaN so the nan-aware statistics below
        # skip them (presumably zero marks pixels without coverage)
        curr_sci = curr_hdul["SCI"]
        curr_sci.data[curr_sci.data == 0] = np.nan

        for name, candidate_func in (
            ("interp", reproject_interp),
            ("exact", reproject_exact),
            ("adaptive", reproject_adaptive),
        ):
            if reproject_func == name:
                r_func = candidate_func
                break
        else:
            raise ValueError(f"reproject_func should be one of {ALLOWED_REPROJECT_FUNCS}")

        with fits.open(prev_filename) as prev_hdul:
            prev_sci = prev_hdul["SCI"]
            prev_sci.data[prev_sci.data == 0] = np.nan

            # Both images go through the same reprojection onto the current
            # SCI header, with warnings silenced for the duration
            target_header = curr_sci.header
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                curr_data = r_func(
                    curr_sci,
                    target_header,
                    return_footprint=False,
                )
                prev_data = r_func(
                    prev_sci,
                    target_header,
                    return_footprint=False,
                )

            diff = curr_data - prev_data

            # Symmetric limit: the larger magnitude of the chosen percentiles
            diff_percentiles = np.nanpercentile(diff, percentiles)
            v = np.nanmax(np.abs(diff_percentiles))

    return diff, v


class RegressAgainstPreviousStep:
    """Step producing difference-image plots between two pipeline versions"""

    def __init__(
        self,
        target,
        in_dir,
        curr_version,
        prev_version=None,
        file_exts=None,
        reproject_func="interp",
        overwrite=False,
    ):
        """Create diagnostic plots to regress against previous versions

        Args:
            target: Target to consider
            in_dir: Input directory
            curr_version: Current version to compare to...
            prev_version: Previous version
            file_exts: File extensions (in priority order) to search for
            reproject_func: Which reproject function to use. Defaults to 'interp',
                but can also be 'exact' or 'adaptive'
            overwrite: Whether to overwrite or not. Defaults to False

        Raises:
            ValueError: If ``prev_version`` is None, or ``reproject_func`` is
                not one of the allowed names
        """
        if prev_version is None:
            raise ValueError("prev_version should be defined")

        if reproject_func not in ALLOWED_REPROJECT_FUNCS:
            raise ValueError(f"reproject_func should be one of {ALLOWED_REPROJECT_FUNCS}")

        if file_exts is None:
            file_exts = [
                "i2d_anchor.fits",
                "i2d_align.fits",
                "i2d.fits",
            ]

        self.target = target
        self.in_dir = in_dir
        self.curr_version = curr_version
        self.prev_version = prev_version
        self.file_exts = file_exts
        self.reproject_func = reproject_func
        self.overwrite = overwrite

        self.out_dir = os.path.join(
            self.in_dir,
            f"{self.curr_version}_to_{self.prev_version}",
        )

    def _find_best_files(self):
        """Pick one file per image, preferring earlier entries of ``file_exts``

        Files are globbed as ``<in_dir>/<target>/*_<file_ext>``. The stem used
        to recognise "the same image" is the basename with exactly that
        ``_<file_ext>`` suffix removed, and stems are compared for equality.
        Stripping a fixed number of ``_`` tokens, or substring matching,
        wrongly merges different bands when an extension has a single token,
        and merges a science image with its ``_bgr_`` counterpart.

        Returns:
            Tuple ``(files, file_exts)`` of equal-length lists, where
            ``file_exts[i]`` is the extension ``files[i]`` was found with
        """
        best_files = []
        best_file_exts = []
        seen_stems = set()

        for file_ext in self.file_exts:
            suffix = f"_{file_ext}"
            matches = glob.glob(
                os.path.join(self.in_dir, self.target, f"*{suffix}"),
            )

            for file in matches:
                stem = os.path.basename(file)[: -len(suffix)]

                # A higher-priority extension already supplied this image
                if stem in seen_stems:
                    continue

                seen_stems.add(stem)
                best_files.append(file)
                best_file_exts.append(file_ext)

        return best_files, best_file_exts

    def do_step(self):
        """Run previous version regression

        Returns:
            True if the step completed (or had already been run), False if
            plotting reported a failure
        """
        step_complete_file = os.path.join(
            self.out_dir,
            f"{self.target}_regress_against_previous_step_complete.txt",
        )

        # Overwriting only clears the completion marker; the plots themselves
        # are simply rewritten below
        if self.overwrite and os.path.exists(step_complete_file):
            os.remove(step_complete_file)

        if os.path.exists(step_complete_file):
            log.info("Step already run")
            return True

        os.makedirs(self.out_dir, exist_ok=True)

        # Get list of appropriate best files
        all_files, all_file_exts = self._find_best_files()

        # Split by instrument (matched on the basename), sorted by filename
        file_dict = {}
        for key in [
            "nircam",
            "niriss",
            "miri",
        ]:
            matched = sorted(
                (
                    (file, file_ext)
                    for file, file_ext in zip(all_files, all_file_exts)
                    if key in os.path.basename(file)
                ),
                key=lambda pair: pair[0],
            )
            file_dict[key] = {
                "files": [file for file, _ in matched],
                "file_exts": [file_ext for _, file_ext in matched],
            }

        for key in file_dict:
            success = self.regress_plot(
                file_dict=file_dict,
                key=key,
            )

            # Checked per instrument so that a failure is never masked by a
            # later success
            if not success:
                log.warning("Failures detected in previous version regression")
                return False

        # Merge these all into a single pdf doc
        merged_filename = os.path.join(
            self.out_dir,
            f"{self.curr_version}_to_{self.prev_version}_comparisons_merged.pdf",
        )

        pdfs = sorted(
            glob.glob(
                os.path.join(
                    self.out_dir,
                    "*_comparison.pdf",
                )
            )
        )

        with PdfWriter() as merger:
            for pdf in pdfs:
                merger.append(pdf)
            merger.write(merged_filename)
            merger.close()

        # Empty marker file recording that the step has finished
        with open(step_complete_file, "w+"):
            pass

        return True

    def regress_plot(self, file_dict, key):
        """Plot per-instrument comparison

        Science images go on the first row; if any background (``_bgr_``)
        images are present they go on a second row. The figure is saved as
        both PNG and PDF in ``self.out_dir``.

        Args:
            file_dict: Dictionary of files, separated by instrument. Each
                entry holds parallel "files" and "file_exts" sequences
            key: Instrument key

        Returns:
            True once plotting has finished (also when there are no files)
        """
        fancy_name = {
            "miri": "MIRI",
            "nircam": "NIRCam",
            "niriss": "NIRISS",
        }[key]

        log.info(f"Plotting up {key}")

        files = list(file_dict[key]["files"])
        file_exts = list(file_dict[key]["file_exts"])

        if not files:
            return True

        # Decide background vs. science from the basename only, so the grid
        # size and the per-panel placement below always agree
        short_names = [os.path.basename(file) for file in files]
        is_bgr_flags = ["_bgr_" in short_name for short_name in short_names]
        n_backgrounds = sum(is_bgr_flags)
        n_sci = len(files) - n_backgrounds

        plot_name = os.path.join(self.out_dir, f"{self.target}_{key}_comparison")

        if n_backgrounds > 0:
            nrows = 2
            ncols = max(n_backgrounds, n_sci)
        else:
            nrows = 1
            ncols = len(files)

        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(4 * ncols, 4 * nrows),
            squeeze=False,
        )

        try:
            # Hide every panel up front so unused grid cells stay blank
            for panel in axes.ravel():
                panel.axis("off")

            # Row-major panel indices: science fills the first row,
            # backgrounds the second
            sci_idx = 0
            bgr_idx = ncols

            for file, file_short, file_ext, is_bgr in zip(
                files, short_names, file_exts, is_bgr_flags
            ):
                band = file_short.split("_")[3]

                diff, v = get_diff_image(
                    file,
                    v_curr=self.curr_version,
                    v_prev=self.prev_version,
                    curr_file_ext=file_ext,
                    file_exts=self.file_exts,
                    reproject_func=self.reproject_func,
                )

                if is_bgr:
                    ax = axes.flat[bgr_idx]
                    bgr_idx += 1
                else:
                    ax = axes.flat[sci_idx]
                    sci_idx += 1

                if diff is None:
                    ax.text(
                        0.5,
                        0.5,
                        f"Not present in {self.prev_version}",
                        ha="center",
                        va="center",
                        fontweight="bold",
                        bbox=dict(fc="white", ec="black", alpha=0.9),
                        transform=ax.transAxes,
                    )
                else:
                    # Colour scale symmetric about zero
                    vmin, vmax = -v, v
                    im = ax.imshow(
                        diff,
                        vmin=vmin,
                        vmax=vmax,
                        cmap=cmocean.cm.balance,
                        origin="lower",
                        interpolation="nearest",
                    )

                    divider = make_axes_locatable(ax)
                    cax = divider.append_axes("right", size="5%", pad=0)
                    fig.colorbar(im, cax=cax, label="MJy/sr")

                band_text = band.upper()
                if is_bgr:
                    band_text += " bgr"

                ax.text(
                    0.05,
                    0.95,
                    band_text,
                    ha="left",
                    va="top",
                    fontweight="bold",
                    bbox=dict(fc="white", ec="black", alpha=0.9),
                    transform=ax.transAxes,
                )

            fig.suptitle(f"{self.target.upper()}, {fancy_name}")
            fig.tight_layout()

            fig.savefig(f"{plot_name}.png", bbox_inches="tight")
            fig.savefig(f"{plot_name}.pdf", bbox_inches="tight")
        finally:
            # Always release the figure, even if a panel failed
            plt.close(fig)

        return True