import copy
import gc
import glob
import logging
import multiprocessing as mp
import os
import shutil
import warnings
from functools import partial
import astropy.units as u
import crds
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from astropy.stats import sigma_clipped_stats
from astropy.wcs import WCS
from jwst.flatfield.flat_field import do_correction
from mpl_toolkits.axes_grid1 import make_axes_locatable
from reproject import reproject_interp, reproject_adaptive, reproject_exact
from reproject.mosaicking import find_optimal_celestial_wcs
from reproject.mosaicking.background import determine_offset_matrix, solve_corrections_sgd
from scipy.ndimage import uniform_filter, median_filter
from stdatamodels.jwst import datamodels
from tqdm import tqdm
from ..utils import get_dq_bit_mask, make_source_mask, reproject_image, level_data
# Non-interactive backend: plots are only ever written to disk
matplotlib.use("agg")
# STIX fonts for both text and mathtext, at a common base size
for _rc_key, _rc_value in {
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
    "font.size": 14,
}.items():
    matplotlib.rcParams[_rc_key] = _rc_value

# Module logger (shares the "stpipe" logger name); NullHandler avoids
# "no handler" warnings when nothing else has configured logging
log = logging.getLogger("stpipe")
log.addHandler(logging.NullHandler())
# How the reprojected stack is combined into an average image.
# Only "mean" makes use of per-pixel weights.
ALLOWED_WEIGHT_METHODS = ["mean", "median", "sigma_clip"]

# Per-pixel weighting schemes: exposure time, inverse read-noise variance, or uniform
ALLOWED_WEIGHT_TYPES = ["exptime", "ivm", "equal"]

# Reprojection backends (reproject_interp / reproject_adaptive / reproject_exact)
ALLOWED_REPROJECT_FUNCS = ["interp", "adaptive", "exact"]
def get_rotation_angle(wcs):
    """Return the rotation angle of a WCS, in degrees.

    The CDELT scaling is folded into the PC matrix to give the full linear
    transform M, and the angle is ``arctan2(-M[0, 1], M[0, 0])``, so it lies
    in (-180, 180].

    Args:
        wcs: WCS instance (must expose ``wcs.wcs.get_cdelt()`` and
            ``wcs.wcs.get_pc()``)

    Returns:
        Rotation angle in degrees
    """
    core = wcs.wcs
    scaled_pc = np.dot(np.diag(core.get_cdelt()), core.get_pc())
    north_rad = np.arctan2(-scaled_pc[0, 1], scaled_pc[0, 0])
    return (north_rad * u.rad).to(u.deg).value
def make_diagnostic_plot(
    plot_name,
    data,
    stripes,
    figsize=(9, 4),
):
    """Make a three-panel diagnostic plot showing the destriping.

    Panels, left to right: the original data, the stripe (noise) model, and
    the data with the stripe model subtracted. The first and last panels share
    a 10th-90th percentile stretch of the original data; the stripe model uses
    its own 1st-99th percentile stretch. The figure is saved as both
    ``{plot_name}.png`` and ``{plot_name}.pdf``.

    Args:
        plot_name: Output name for plot, without extension
        data: Original data
        stripes: Stripe model
        figsize: Size for the figure. Defaults to (9, 4)

    Returns:
        True once the figure has been written and closed
    """
    plt.figure(figsize=figsize)

    data_stretch = np.nanpercentile(data, [10, 90])
    stripe_stretch = np.nanpercentile(stripes, [1, 99])

    # (label, image, (vmin, vmax)) for each panel, in display order
    panels = (
        ("Original", data, data_stretch),
        ("Noise model", stripes, stripe_stretch),
        ("Destriped", data - stripes, data_stretch),
    )

    for position, (label, image, (low, high)) in enumerate(panels, start=1):
        ax = plt.subplot(1, 3, position)
        im = plt.imshow(
            image,
            origin="lower",
            interpolation="nearest",
            vmin=low,
            vmax=high,
        )
        plt.axis("off")
        plt.text(
            0.05,
            0.95,
            label,
            ha="left",
            va="top",
            fontweight="bold",
            bbox=dict(facecolor="white", edgecolor="black", alpha=0.7),
            transform=ax.transAxes,
        )
        # Horizontal colourbar glued directly beneath each panel
        colourbar_ax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0)
        plt.colorbar(im, cax=colourbar_ax, label="MJy/sr", orientation="horizontal")

    plt.subplots_adjust(wspace=0.01)
    for extension in ("png", "pdf"):
        plt.savefig(f"{plot_name}.{extension}", bbox_inches="tight")
    plt.close()

    return True
def parallel_reproject_weight(
    idx,
    files,
    optimal_wcs,
    optimal_shape,
    weight_type="exptime",
    do_level_data=True,
    reproject_func="interp",
):
    """Function to parallelise reprojecting with associated weights

    The reprojected data has NaNs replaced by 0, and 0 is then treated as the
    "no data" sentinel: every weight type gives zero weight to those pixels.

    Args:
        idx: File idx to reproject
        files: Full stack of files
        optimal_wcs: Optimal WCS for image stack
        optimal_shape: Optimal shape for image shape
        weight_type: How to weight the average image. One of "exptime"
            (exposure time), "ivm" (inverse of the reprojected VAR_RNOISE),
            or "equal". Defaults to "exptime"
        do_level_data: Whether to level data or not. Defaults to True
        reproject_func: Which reproject function to use. Defaults to 'interp',
            but can also be 'exact' or 'adaptive'

    Returns:
        (idx, (file, data_array, weight_array)). idx is returned so callers
        using unordered parallel maps can put results back in order

    Raises:
        ValueError: if weight_type is not one of ALLOWED_WEIGHT_TYPES
    """
    # Validate before doing any (potentially expensive) reprojection
    if weight_type not in ALLOWED_WEIGHT_TYPES:
        raise ValueError(f"weight_type should be one of {ALLOWED_WEIGHT_TYPES}")

    file = files[idx]
    data_array = reproject_image(
        file,
        optimal_wcs=optimal_wcs,
        optimal_shape=optimal_shape,
        do_level_data=do_level_data,
        reproject_func=reproject_func,
    )

    # Set any bad data to 0
    data_array.array[np.isnan(data_array.array)] = 0

    if weight_type == "exptime":
        # Create array of exposure time
        with datamodels.open(file) as model:
            exptime = model.meta.exposure.exposure_time
        weight_array = copy.deepcopy(data_array)
        weight_array.array[np.isfinite(weight_array.array)] = exptime
    elif weight_type == "ivm":
        # Reproject the VAR_RNOISE array and take inverse
        weight_array = reproject_image(
            file,
            optimal_wcs=optimal_wcs,
            optimal_shape=optimal_shape,
            hdu_type="var_rnoise",
            reproject_func=reproject_func,
        )
        with np.errstate(divide="ignore", invalid="ignore"):
            weight_array.array = weight_array.array ** -1
        # Zero variance gives inf and missing variance gives NaN; neither is a
        # usable weight
        weight_array.array[~np.isfinite(weight_array.array)] = 0
    else:
        # "equal" weighting
        weight_array = copy.deepcopy(data_array)
        weight_array.array[np.isfinite(weight_array.array)] = 1

    # Pixels with no valid data must not contribute weight, whatever the
    # weighting scheme (previously missed for "ivm", biasing means toward 0)
    weight_array.array[data_array.array == 0] = 0

    return idx, (file, data_array, weight_array)
class MultiTileDestripeStep:
    """Destripe a set of overlapping tiles against their combined average image"""

    def __init__(
        self,
        in_dir,
        out_dir,
        step_ext,
        procs,
        apply_to_unflat=False,
        do_convergence=False,
        convergence_sigma=1,
        convergence_max_iterations=5,
        weight_method="mean",
        weight_type="ivm",
        do_level_match=False,
        quadrants=True,
        min_mask_frac=0.2,
        do_vertical_subtraction=False,
        do_large_scale=True,
        large_scale_filter_scale=None,
        large_scale_filter_extend_mode="reflect",
        sigma=3,
        dilate_size=7,
        maxiters=None,
        reproject_func="interp",
        overwrite=False,
    ):
        """Subtracts large-scale stripes using dither information

        Create a weighted average image, then do a sigma-clipped median along (optionally)
        columns and rows (optionally by quadrants), after optionally smoothing the stacked
        image to attempt to remove persistent large-scale ripples.

        If you see clear oversubtraction in the data, you should set do_large_scale to False.
        In most cases, it appears to work well but there may be some edge cases where it doesn't
        work well.

        Args:
            in_dir: Input directory
            out_dir: Output directory. Destriped files are written here, and diagnostic
                plots into a "plots" subdirectory
            step_ext: .fits extension for the files going into the step. Files are
                matched as ``*_{step_ext}.fits``
            procs: Number of processes to run in parallel
            apply_to_unflat: If True, will undo the flat-fielding before applying the
                stripe model, and then reapply it. Defaults to False
            do_convergence: Whether to loop this iteratively until convergence, or just
                do a single run. Defaults to False
            convergence_sigma: The loop is considered converged once max(|stripes / err|)
                is below this for every file. Defaults to 1
            convergence_max_iterations: Maximum number of iterations to run. Defaults to 5
            weight_method: Method for stacking the image. Should be one of 'mean',
                'median', 'sigma_clip'. Only 'mean' uses the weights. Defaults to 'mean'
            weight_type: How to weight the stacked image. Should be one of 'exptime',
                'ivm' (inverse readnoise variance), 'equal'. Defaults to 'ivm'
            do_level_match: Whether to do a simple match between tiles. Should be set
                to False if this is run after level_match_step. Defaults to False
            quadrants: Whether to split up stripes per-amplifier. Defaults to True
            min_mask_frac: Minimum fraction of unmasked data in quadrants to calculate a
                median. Defaults to 0.2 (i.e. 20% unmasked)
            do_vertical_subtraction: Whether to also do a step of vertical stripe
                subtraction. Defaults to False
            do_large_scale: Whether to do filtering to try and remove large, consistent
                ripples between data. Defaults to True
            large_scale_filter_scale: Factor by which we smooth for large scale persistent
                ripple removal. Defaults to None, which will use a scale ~10% of the data shape
            large_scale_filter_extend_mode: How to extend values in the filter beyond
                array edge. Default is "reflect". See the scipy.ndimage docs for more info
            sigma: sigma value for sigma-clipped statistics. Defaults to 3
            dilate_size: Dilation size for mask creation. Defaults to 7
            maxiters: Maximum number of sigma-clipping iterations. Defaults to None
            reproject_func: Which reproject function to use. Should be one of 'interp',
                'adaptive', 'exact'. Defaults to 'interp'
            overwrite: Whether to overwrite or not. Defaults to False

        Raises:
            ValueError: if weight_method, weight_type or reproject_func is not an
                allowed value
        """
        if weight_method not in ALLOWED_WEIGHT_METHODS:
            raise ValueError(
                f"weight_method should be one of {ALLOWED_WEIGHT_METHODS}, not {weight_method}"
            )
        if weight_type not in ALLOWED_WEIGHT_TYPES:
            raise ValueError(
                f"weight_type should be one of {ALLOWED_WEIGHT_TYPES}, not {weight_type}"
            )
        if reproject_func not in ALLOWED_REPROJECT_FUNCS:
            raise ValueError(
                f"reproject_func should be one of {ALLOWED_REPROJECT_FUNCS}, not {reproject_func}"
            )

        if weight_method in ["median", "sigma_clip"]:
            log.info(f"Using {weight_method} for creating average image, will not use weighting")

        self.in_dir = in_dir
        self.out_dir = out_dir
        self.plot_dir = os.path.join(
            self.out_dir,
            "plots",
        )
        self.step_ext = step_ext
        self.procs = procs
        self.apply_to_unflat = apply_to_unflat
        self.do_convergence = do_convergence
        self.convergence_sigma = convergence_sigma
        self.convergence_max_iterations = convergence_max_iterations
        self.weight_method = weight_method
        self.weight_type = weight_type
        self.do_level_match = do_level_match
        self.quadrants = quadrants
        self.min_mask_frac = min_mask_frac
        self.do_large_scale = do_large_scale
        self.do_vertical_subtraction = do_vertical_subtraction
        self.large_scale_filter_scale = large_scale_filter_scale
        self.large_scale_filter_extend_mode = large_scale_filter_extend_mode
        self.sigma = sigma
        self.dilate_size = dilate_size
        self.maxiters = maxiters
        self.reproject_func = reproject_func
        self.overwrite = overwrite

        # State filled in while the step runs
        self.files_reproj = None
        self.data_avg = None
        self.data_avg_smooth = None
        self.data_avg_mask = None
        self.optimal_wcs = None
        self.optimal_shape = None

    @staticmethod
    def _fill_nans(values):
        """Linearly interpolate over NaNs in a 1D array, in place.

        NaNs beyond the first/last finite value take that end value (np.interp
        behaviour). Arrays with no NaNs, or only NaNs, are left untouched.

        Args:
            values: 1D float array (or writeable view) to fill

        Returns:
            The same array, for convenience
        """
        mask = np.isnan(values)
        # Only interp if we have a) some NaNs but not b) all NaNs
        if 0 < np.sum(mask) < len(values):
            values[mask] = np.interp(
                np.flatnonzero(mask),
                np.flatnonzero(~mask),
                values[~mask],
            )
        return values

    @staticmethod
    def _fold_angle_diff(angle_diff):
        """Fold rotation-angle differences (degrees) into [0, 90].

        Stripes are undirected lines, so 180 deg differences are equivalent and the
        result is the distance from the nearest of 0 or 180 deg.

        Args:
            angle_diff: Array of angle differences in degrees

        Returns:
            New array of folded differences in [0, 90]
        """
        folded = np.abs(angle_diff)
        folded[folded >= 180] -= 180
        folded = np.abs(folded)
        folded[folded >= 90] -= 180
        return np.abs(folded)

    def do_step(self):
        """Run multi-tile destriping

        Returns:
            True on success (or if the step had already been run). False if no input
            files were found or the reprojected stack could not be created
        """
        if self.overwrite and os.path.exists(self.out_dir):
            shutil.rmtree(self.out_dir)

        os.makedirs(self.out_dir, exist_ok=True)
        os.makedirs(self.plot_dir, exist_ok=True)

        # Check if we've already run the step
        step_complete_file = os.path.join(
            self.out_dir,
            "multi_tile_destripe_step_complete.txt",
        )
        if os.path.exists(step_complete_file):
            log.info("Step already run")
            return True

        files = sorted(
            glob.glob(
                os.path.join(
                    self.in_dir,
                    f"*_{self.step_ext}.fits",
                )
            )
        )
        if not files:
            log.warning(f"No *_{self.step_ext}.fits files found in {self.in_dir}")
            return False

        # Ensure we're not wasting processes
        procs = min(self.procs, len(files))

        # Get out the optimal WCS, since we only need to calculate this once
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            optimal_wcs, optimal_shape = find_optimal_celestial_wcs(
                files,
                hdu_in="SCI",
                auto_rotate=True,
            )
        self.optimal_wcs = optimal_wcs
        self.optimal_shape = optimal_shape

        converged = False
        iteration = 1
        while not converged:
            if self.do_convergence:
                log.info(f"Performing iteration {iteration}")

            # Create weighted images
            success = self.weighted_reproject_image(
                files,
                procs=procs,
                do_large_scale=False,
            )
            if not success:
                log.warning("Error in creating reproject stack")
                return False

            stripe_sigma = self.run_multi_tile_destripe(
                procs=procs,
                iteration=iteration,
                do_large_scale=False,
            )

            # Use the output files as potential further input
            files = sorted(
                glob.glob(
                    os.path.join(
                        self.out_dir,
                        f"*_{self.step_ext}.fits",
                    )
                )
            )

            # If doing large-scale we repeat, but using the output files
            if self.do_large_scale:
                log.info("Now doing large-scale smoothing for destriping")

                success = self.weighted_reproject_image(
                    files,
                    procs=procs,
                    do_large_scale=True,
                )
                if not success:
                    log.warning("Error in creating reproject stack")
                    return False

                stripe_sigma = self.run_multi_tile_destripe(
                    procs=procs,
                    iteration=iteration,
                    do_large_scale=True,
                )

            # stripe_sigma is ordered like self.files_reproj (the files that were
            # just destriped), so pair them directly for logging
            if not self.do_convergence:
                # If we're not iterating, then say we've converged
                converged = True
            elif np.all(stripe_sigma < self.convergence_sigma):
                log.info("Convergence reached! Final stripe sigma values:")
                for file, file_sigma in zip(self.files_reproj, stripe_sigma):
                    log.info(f"{os.path.split(file)[-1]}: {file_sigma}")
                converged = True
            elif iteration < self.convergence_max_iterations:
                log.info("Destriping not converged. Continuing")
            else:
                log.info(
                    "Destriping not converged but max iterations reached. Final stripe sigma values:"
                )
                for file, file_sigma in zip(self.files_reproj, stripe_sigma):
                    log.info(f"{os.path.split(file)[-1]}, {file_sigma}")
                converged = True

            iteration += 1

        # Touch the marker file so re-runs are skipped
        with open(step_complete_file, "w+"):
            pass

        return True

    def weighted_reproject_image(
        self,
        files,
        procs=1,
        do_large_scale=False,
    ):
        """Get reprojected images (and weights), and build the average image

        Sets self.files_reproj and self.data_avg. If do_large_scale, also works out
        the smoothing direction from the input/stack rotations and sets
        self.data_avg_smooth and self.data_avg_mask (and self.large_scale_filter_scale
        if it was None).

        Args:
            files (list): Files to reproject
            procs (int): Number of processes to use. Defaults to 1.
            do_large_scale: Is this a large-scale smoothed subtraction? Defaults to False

        Returns:
            True
        """
        n_files = len(files)
        data_reproj = [None] * n_files
        weight_reproj = [None] * n_files
        files_reproj = [None] * n_files

        log.info("Reprojecting images (and weights)")

        with mp.get_context("fork").Pool(procs) as pool:
            for i, result in tqdm(
                pool.imap_unordered(
                    partial(
                        parallel_reproject_weight,
                        files=files,
                        optimal_wcs=self.optimal_wcs,
                        optimal_shape=self.optimal_shape,
                        weight_type=self.weight_type,
                        do_level_data=True,
                        reproject_func=self.reproject_func,
                    ),
                    range(n_files),
                ),
                total=n_files,
                desc="Weighted reprojects",
                ascii=True,
            ):
                # Results arrive unordered; i is the original file index
                files_reproj[i], data_reproj[i], weight_reproj[i] = result

            pool.close()
            pool.join()
        gc.collect()

        self.files_reproj = files_reproj

        # Create the weighted average image here
        log.info("Creating average image")
        self.data_avg = self.create_weighted_avg_image(
            data_reproj,
            weight_reproj,
        )

        # Smooth out the data for large-scale correction
        if do_large_scale:
            log.info("Figuring out rotation between input images and stack")

            ref_rot = get_rotation_angle(self.optimal_wcs)

            indiv_rots = []
            for file in self.files_reproj:
                with datamodels.open(file) as im:
                    w = WCS(im.meta.wcs.to_fits_sip())
                    indiv_rots.append(get_rotation_angle(w))

                    # Also get the filter scale, if necessary. This is taken from the
                    # first file and then kept for the rest of the run
                    if self.large_scale_filter_scale is None:
                        filter_scale = max(im.data.shape[0] // 10, 1)
                        # Keep the filter size odd (symmetric about the central pixel)
                        if filter_scale % 2 == 0:
                            filter_scale -= 1
                        self.large_scale_filter_scale = filter_scale

            indiv_rots = np.array(indiv_rots)

            # Look for big differences between tiles, and wrt the reference WCS
            internal_diff = self._fold_angle_diff(indiv_rots - indiv_rots[0])
            ref_diff = self._fold_angle_diff(indiv_rots - ref_rot)

            if not np.all(internal_diff < 10):
                # First case, we have a weird mix of rotations
                direction = None
                log.info("Input images have a variety of rotations. Defaulting to smoothing over both axes")
            elif np.all(ref_diff < 10):
                # Second case, all similar to the reference rotation, so stripes are horizontal
                direction = "horizontal"
                log.info("Stacked image is at same rotation as input images")
            elif np.all(np.logical_and(80 < ref_diff, ref_diff < 100)):
                # Third case, perpendicular to the reference rotation, so stripes are vertical
                direction = "vertical"
                log.info("Stacked image is perpendicular to input images")
            else:
                # Final case, aligned but not along any particular axis in the image
                direction = None
                log.info("Stacked image does not align over a particular axis. Smoothing over both axes")

            log.info("Creating smoothed image")
            self.data_avg_smooth, self.data_avg_mask = self.get_data_avg_smooth(direction=direction)

        return True

    def run_multi_tile_destripe(
        self,
        procs=1,
        iteration=1,
        do_large_scale=False,
    ):
        """Wrap parallelism around the multi-tile destriping, and save the outputs

        Each file in self.files_reproj is destriped in parallel, then the stripe model
        is subtracted (optionally in un-flat-fielded space) and the result is saved
        to out_dir under the same file name.

        Args:
            procs: Number of parallel processes. Defaults to 1
            iteration: What iteration are we on? Defaults to 1
            do_large_scale: Is this a large-scale smoothed subtraction? Defaults to False

        Returns:
            Array of max(|stripes / err|) per file, ordered like self.files_reproj.
            Pixels with err == 0 are ignored. Files with no result keep 0
        """
        log.info("Running multi-tile destripe")

        n_files = len(self.files_reproj)
        results = [None] * n_files

        with mp.get_context("fork").Pool(procs) as pool:
            for i, result in tqdm(
                pool.imap_unordered(
                    partial(
                        self.parallel_multi_tile_destripe,
                        iteration=iteration,
                        do_large_scale=do_large_scale,
                    ),
                    range(n_files),
                ),
                total=n_files,
                ascii=True,
                desc="Multi-tile destriping",
            ):
                # Results arrive unordered; i is the index into self.files_reproj
                results[i] = result

            pool.close()
            pool.join()
        gc.collect()

        stripe_sigma = np.zeros(n_files)

        for idx, result in enumerate(results):
            if result is None:
                continue

            in_file, stripes = result
            out_file = os.path.join(
                self.out_dir,
                os.path.split(in_file)[-1],
            )

            with datamodels.open(in_file) as im:
                # Remember where the data was missing so we can restore it afterwards
                zero_idx = np.where(im.data == 0)
                nan_idx = np.where(np.isnan(im.data))

                if self.apply_to_unflat:
                    # Get CRDS context
                    try:
                        crds_context = os.environ["CRDS_CONTEXT"]
                    except KeyError:
                        crds_context = crds.get_default_context()

                    crds_dict = {
                        "INSTRUME": "NIRCAM",
                        "DETECTOR": im.meta.instrument.detector,
                        "FILTER": im.meta.instrument.filter,
                        "PUPIL": im.meta.instrument.pupil,
                        "DATE-OBS": im.meta.observation.date,
                        "TIME-OBS": im.meta.observation.time,
                    }
                    flats = crds.getreferences(crds_dict, reftypes=["flat"], context=crds_context)
                    flatfile = flats["flat"]

                    with datamodels.FlatModel(flatfile) as flat:
                        flat_inverse = copy.deepcopy(flat)
                        flat_inverse.data = flat_inverse.data ** -1

                        # First, unapply the flat fielding to the image and subtract stripes
                        unflattened_im, _ = do_correction(im, flat_inverse)
                        unflattened_im.data -= stripes

                        # Reflatten and save into the original data array
                        reflattened_im, _ = do_correction(unflattened_im, flat)
                        im.data = copy.deepcopy(reflattened_im.data)
                else:
                    im.data -= stripes

                # If we're not in subarray mode, here we want to level out between amplifiers
                # for safety
                if "sub" not in im.meta.subarray.name.lower():
                    im.data = level_data(im)

                im.data[zero_idx] = 0
                im.data[nan_idx] = np.nan

                im.save(out_file)

                # Maximum absolute stripe amplitude relative to the error. Zero errors
                # become NaN so they are ignored rather than producing inf
                err = copy.deepcopy(im.err)
                err[err == 0] = np.nan
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    stripe_sigma[idx] = np.nanmax(np.abs(stripes / err))

            del im
            gc.collect()

        return stripe_sigma

    def parallel_multi_tile_destripe(
        self,
        idx,
        iteration=1,
        do_large_scale=False,
    ):
        """Function to parallelise up multi-tile destriping

        Args:
            idx: Index of file to be destriped
            iteration: What iteration are we on? Defaults to 1
            do_large_scale: Is this a large-scale smoothed subtraction? Defaults to False

        Returns:
            (idx, (file, stripes)), so unordered results can be put back in order
        """
        file = self.files_reproj[idx]
        result = self.multi_tile_destripe(
            file,
            iteration=iteration,
            do_large_scale=do_large_scale,
        )
        return idx, result

    def create_weighted_avg_image(
        self,
        data,
        weights,
    ):
        """Create an average image from a bunch of reprojected ones

        Args:
            data: List of reprojected data objects (each with ``.array`` and
                ``.view_in_original_array``). 0 marks missing data
            weights: List of weights, in the same form. Should be same length as data

        Returns:
            Average image of shape self.optimal_shape, with NaN wherever there is
            no data (zero total weight, zero, or non-finite average)

        Raises:
            ValueError: if self.weight_method is not an allowed value
        """
        # Start by calculating corrections to match between tiles, if not
        # already done
        if self.do_level_match:
            offset_matrix = determine_offset_matrix(data)
            corrections = solve_corrections_sgd(offset_matrix)
            for array, correction in zip(data, corrections):
                # Keep the 0 "no data" sentinel intact
                zero_idx = np.where(array.array == 0)
                array.array -= correction
                array.array[zero_idx] = 0

        # Mean accumulates into a 2D image; median/sigma_clip need a 3D cube
        if self.weight_method == "mean":
            output_array = np.zeros(self.optimal_shape)
        elif self.weight_method in ["median", "sigma_clip"]:
            output_array = np.zeros([self.optimal_shape[0], self.optimal_shape[1], len(data)])
        else:
            raise ValueError(f"weight_method should be one of {ALLOWED_WEIGHT_METHODS}")

        output_weights = np.zeros_like(output_array)

        for i, (array, weight) in enumerate(zip(data, weights)):
            if self.weight_method == "mean":
                output_array[array.view_in_original_array] += array.array * weight.array
                output_weights[weight.view_in_original_array] += weight.array
            else:
                data_view = array.view_in_original_array
                weight_view = weight.view_in_original_array
                output_array[data_view[0], data_view[1], i] = array.array
                output_weights[weight_view[0], weight_view[1], i] = weight.array

        output_array[output_weights == 0] = np.nan

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if self.weight_method == "mean":
                # Weighted mean
                data_avg = output_array / output_weights
            elif self.weight_method == "median":
                # For the median, ignore weights (apart from 0s). This must use the
                # stacked cube, not the raw input list
                data_avg = np.nanmedian(output_array, axis=-1)
            else:
                # Sigma-clipped median (this will ignore weights)
                data_avg = sigma_clipped_stats(
                    output_array,
                    sigma=self.sigma,
                    maxiters=self.maxiters,
                    axis=-1,
                )[1]

        data_avg[data_avg == 0] = np.nan
        data_avg[~np.isfinite(data_avg)] = np.nan

        return data_avg

    def get_data_avg_smooth(
        self,
        direction=None,
    ):
        """Filter data with a large scale filter

        Will either perform a large-scale median filter over a specific axis,
        or a mean filter over all axes. Also creates a mask

        NaNs in self.data_avg are first filled by interpolating along both axes
        (order depending on direction), and are put back as NaN after filtering.

        Args:
            direction: Direction to smooth over, either "horizontal", "vertical", or None.
                Defaults to None

        Returns:
            (data_smooth, mask), where mask comes from get_mask(self.data_avg - data_smooth)

        Raises:
            ValueError: if direction is not "horizontal", "vertical" or None
        """
        data_avg = copy.deepcopy(self.data_avg)
        nan_idx = np.where(np.isnan(data_avg))

        if direction in ["horizontal", None]:
            interp_order = [1, 0]
        elif direction == "vertical":
            interp_order = [0, 1]
        else:
            raise ValueError("direction should be horizontal, vertical, or None")

        # Extrapolate over axes in order based on where the stripes are. Each line is
        # a view, so filling it updates data_avg in place
        for order in interp_order:
            for ax in range(data_avg.shape[order]):
                line = data_avg[ax, :] if order == 0 else data_avg[:, ax]
                self._fill_nans(line)

        log.info(f"Smoothing with a filter scale of {self.large_scale_filter_scale}")

        if direction is None:
            data_smooth = uniform_filter(
                data_avg,
                size=self.large_scale_filter_scale,
                mode=self.large_scale_filter_extend_mode,
            )
        else:
            data_smooth = np.zeros_like(data_avg)
            if direction == "horizontal":
                # Median-filter each column along the y-axis
                for col_idx in range(data_avg.shape[1]):
                    data_smooth[:, col_idx] = median_filter(
                        data_avg[:, col_idx],
                        size=self.large_scale_filter_scale,
                        mode=self.large_scale_filter_extend_mode,
                    )
            else:
                # Median-filter each row along the x-axis
                for row_idx in range(data_avg.shape[0]):
                    data_smooth[row_idx, :] = median_filter(
                        data_avg[row_idx, :],
                        size=self.large_scale_filter_scale,
                        mode=self.large_scale_filter_extend_mode,
                    )

        data_smooth[nan_idx] = np.nan
        mask = self.get_mask(self.data_avg - data_smooth)

        return data_smooth, mask

    def get_mask(
        self,
        data,
    ):
        """Create positive/negative mask

        Builds one mask with make_source_mask on data and a second on -data (passing
        the first as an existing mask), then ORs them, so features of either sign
        are masked. Uses this step's sigma, dilate_size and maxiters.

        Args:
            data: 2D image to mask

        Returns:
            Boolean mask (True = masked)
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mask_pos = make_source_mask(
                data,
                nsigma=self.sigma,
                dilate_size=self.dilate_size,
                sigclip_iters=self.maxiters,
            )
            mask_neg = make_source_mask(
                -data,
                mask=mask_pos,
                nsigma=self.sigma,
                dilate_size=self.dilate_size,
                sigclip_iters=self.maxiters,
            )
        return mask_pos | mask_neg

    def multi_tile_destripe(
        self,
        file,
        iteration=1,
        do_large_scale=False,
    ):
        """Do a row-by-row, column-by-column data subtraction using other dither information

        Reproject average image, optionally remove persistent large-scale stripes, then do
        a sigma-clipped median along columns and rows (optionally by quadrants). A
        diagnostic plot is written to self.plot_dir.

        Args:
            file (str): File to correct
            iteration: What iteration are we on? Defaults to 1
            do_large_scale: Is this a large-scale smoothed subtraction? Defaults to False

        Returns:
            (file, stripes_arr): the input file name and the 2D stripe model to subtract

        Raises:
            ValueError: if self.reproject_func is not an allowed value
        """
        if self.reproject_func == "interp":
            r_func = reproject_interp
        elif self.reproject_func == "exact":
            r_func = reproject_exact
        elif self.reproject_func == "adaptive":
            r_func = reproject_adaptive
        else:
            raise ValueError(f"reproject_func should be one of {ALLOWED_REPROJECT_FUNCS}")

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            with datamodels.open(file) as model:
                file_name = model.meta.filename
                is_subarray = "sub" in model.meta.subarray.name.lower()

                # If we're in subarray mode or doing large-scale, turn off quadrants
                quadrants = bool(self.quadrants)
                if is_subarray or do_large_scale:
                    quadrants = False

                # Level between amplifiers, but never for subarrays (matching
                # run_multi_tile_destripe). The large-scale pass runs on outputs that
                # run_multi_tile_destripe has already levelled, so skip it there too
                if not is_subarray and not do_large_scale:
                    model.data = level_data(model)

                dq_bit_mask = get_dq_bit_mask(model.dq)

                # Pull out data and DQ mask
                data = copy.deepcopy(model.data)
                data[dq_bit_mask != 0] = np.nan

                wcs = model.meta.wcs.to_fits_sip()

            # Reproject the average image onto this file's frame
            data_avg = r_func(
                (self.data_avg, self.optimal_wcs),
                wcs,
                return_footprint=False,
            )

            # If we're also attempting to remove large-scale ripples, we filter the average
            # data here and correct the average image
            if do_large_scale:
                # Reproject the smoothed data
                data_avg_smooth = r_func(
                    (self.data_avg_smooth, self.optimal_wcs),
                    wcs,
                    return_footprint=False,
                )
                diff_smooth = data_avg - data_avg_smooth

                # Also reproject the mask, casting to bool. This needs to use
                # reproject_interp, so we can keep whole numbers
                mask_smooth = reproject_interp(
                    (self.data_avg_mask, self.optimal_wcs),
                    wcs,
                    order="nearest-neighbor",
                    return_footprint=False,
                )
                mask_smooth = np.array(mask_smooth, dtype=bool)

                # Get the low-level stripes left in the data
                stripes_smooth = sigma_clipped_stats(
                    diff_smooth,
                    mask=mask_smooth,
                    sigma=self.sigma,
                    maxiters=self.maxiters,
                    axis=1,
                )[1]
                self._fill_nans(stripes_smooth)

                data_avg = data_avg - stripes_smooth[:, np.newaxis]

            diff = data - data_avg
            diff -= np.nanmedian(diff)

            stripes_arr = np.zeros_like(diff)
            mask_diff = self.get_mask(diff)

            if self.do_vertical_subtraction:
                # First, subtract the y
                stripes_y = sigma_clipped_stats(
                    diff - stripes_arr,
                    mask=mask_diff,
                    sigma=self.sigma,
                    maxiters=self.maxiters,
                    axis=0,
                )[1]

                # Centre around 0, replace NaNs with nearest value
                stripes_y -= np.nanmedian(stripes_y)
                self._fill_nans(stripes_y)

                stripes_arr += stripes_y[np.newaxis, :]

            stripes_x_2d = np.zeros_like(stripes_arr)

            # Sigma-clip the diff across the whole image
            stripes_x_full = sigma_clipped_stats(
                diff - stripes_arr,
                mask=mask_diff,
                sigma=self.sigma,
                maxiters=self.maxiters,
                axis=1,
            )[1]
            stripes_x_full[stripes_x_full == 0] = np.nan

            if quadrants:
                quadrant_size = stripes_arr.shape[1] // 4

                for quadrant in range(4):
                    idx_slice = slice(quadrant * quadrant_size, (quadrant + 1) * quadrant_size)

                    # Sigma-clip the diff
                    diff_quadrants = diff[:, idx_slice] - stripes_arr[:, idx_slice]
                    mask_quadrants = mask_diff[:, idx_slice]

                    stripes_x = sigma_clipped_stats(
                        diff_quadrants,
                        mask=mask_quadrants,
                        sigma=self.sigma,
                        maxiters=self.maxiters,
                        axis=1,
                    )[1]
                    stripes_x[stripes_x == 0] = np.nan

                    # For anything with less than the requisite amount of unmasked pixels,
                    # fall back to the full row median
                    mask_sum = np.nansum(~np.asarray(mask_quadrants, dtype=bool), axis=1)
                    too_masked_idx = np.where(mask_sum < quadrant_size * self.min_mask_frac)
                    stripes_x[too_masked_idx] = stripes_x_full[too_masked_idx]

                    # Replace NaNs with nearest values
                    self._fill_nans(stripes_x)

                    # Centre around 0, since we've corrected for steps between amplifiers
                    stripes_x -= np.nanmedian(stripes_x)

                    stripes_x_2d[:, idx_slice] += stripes_x[:, np.newaxis]
            else:
                # Centre around 0, replace NaNs with nearest values
                stripes_x_full -= np.nanmedian(stripes_x_full)
                self._fill_nans(stripes_x_full)

                stripes_x_2d += stripes_x_full[:, np.newaxis]

            # Centre around 0 one last time
            stripes_x_2d -= np.nanmedian(stripes_x_2d)
            stripes_arr += stripes_x_2d
            stripes_arr -= np.nanmedian(stripes_arr)

            # Make diagnostic plot. Use different names if we're iterating
            suffix = "_multi_tile_destripe"
            if do_large_scale:
                suffix += "_large_scale"
            if self.do_convergence:
                suffix += f"_iter_{iteration}"

            plot_name = os.path.join(
                self.plot_dir,
                file_name.replace(".fits", suffix),
            )

            make_diagnostic_plot(
                plot_name=plot_name,
                data=data,
                stripes=stripes_arr,
            )

        return file, stripes_arr