import copy
import gc
import glob
import itertools
import logging
import multiprocessing as mp
import os
import random
import shutil
import warnings
from functools import partial
import astropy.units as u
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.stats import sigma_clipped_stats
from astropy.wcs import WCS
from reproject.mosaicking import find_optimal_celestial_wcs
from scipy.optimize import minimize
from stdatamodels.jwst import datamodels
from threadpoolctl import threadpool_limits
from tqdm import tqdm
from ..utils import get_dq_bit_mask, reproject_image, make_source_mask, make_stacked_image, get_band_type
# Rough outline of the MIRI lyot coronagraph region, as (row, column) slices
# into the science array
LYOT_I = slice(735, None)
LYOT_J = slice(None, 290)

# Valid values for the ``reproject_func`` argument
ALLOWED_REPROJECT_FUNCS = ["interp", "adaptive", "exact"]

# Valid values for the ``fit_type`` arguments: a single offset, or a plane
ALLOWED_FIT_TYPES = ["level", "level+slope"]

# Non-interactive Agg backend (no display needed), plus consistent STIX fonts
matplotlib.use("agg")
matplotlib.rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
    "font.size": 14,
})

log = logging.getLogger(__name__)
def get_dithers(files,
                combine_nircam_short=False,
                ):
    """Get the unique dither groups from a list of files

    Each file is reduced to a glob-style key made of its first two
    underscore-separated fields, a literal ``*`` and its second-to-last
    field, e.g. ``a_b_c_d_e.fits`` -> ``a_b_*_d``.

    Args:
        files: List of input files
        combine_nircam_short: Whether to drop the chip number (the final
            character of the key) from NIRCam short observations, replacing
            it with ``*``. Defaults to False

    Returns:
        Sorted numpy array of unique dither keys
    """

    def _dither_key(path):
        parts = os.path.basename(path).split("_")
        key = f"{'_'.join(parts[:2])}_*_{parts[-2]}"
        if combine_nircam_short:
            key = key[:-1] + "*"
        return key

    # np.unique already returns its result sorted
    return np.unique([_dither_key(path) for path in files])
def apply_subtraction(im,
                      delta,
                      fit_type="level",
                      ref_ra=0,
                      ref_dec=0,
                      ref_wcs=None,
                      ref_shape=None,
                      ):
    """Apply subtraction to the image.

    Pixels that are exactly zero on input are reset to zero afterwards, so
    they are not altered by the subtraction.

    Args:
        im: Input datamodel. Its ``data`` array is modified in place
        delta: Coefficients to subtract from the image. For 'level' this is
            [offset]; for 'level+slope' it is [x_slope, y_slope, offset], with
            slopes per pixel of the reference WCS, relative to (ref_ra, ref_dec)
        fit_type: Type of fitting we've done. Options
            are 'level' (the default, just fits a
            single offset), and 'level+slope' (which will
            fit a plane)
        ref_ra: The reference RA point for the fits, in arcsec. Defaults to 0
        ref_dec: The reference Dec point for the fits, in arcsec. Defaults to 0
        ref_wcs: Reference WCS for the fits, since we do in pixel space.
            If None, will assume this is the WCS for the image, which is
            likely incorrect
        ref_shape: Shape for the reference WCS. Kept for backwards
            compatibility only: just the celestial frame of ``ref_wcs`` is
            needed, so this is no longer used

    Returns:
        The input datamodel, with the subtraction applied

    Raises:
        ValueError: If fit_type is not one of ALLOWED_FIT_TYPES
    """
    zero_idx = im.data == 0

    if fit_type == "level":
        im.data -= delta[0]
    elif fit_type == "level+slope":
        # Here, pull out the WCS and use this to convert the delta
        # coefficients to a slope for valid pixels
        wcs = WCS(copy.deepcopy(im.meta.wcs.to_fits_sip()))
        if ref_wcs is None:
            ref_wcs = copy.deepcopy(wcs)

        ii, jj = np.indices(im.data.shape, dtype=float)
        frame_coords = wcs.pixel_to_world(jj, ii)

        # Convert the input world coordinates to the frame of the reference WCS.
        # Only the frame (class + frame attributes) is needed, so evaluate a
        # single pixel rather than the whole (potentially mosaic-sized) grid
        ref_frame = ref_wcs.pixel_to_world(0, 0).frame
        frame_coords = frame_coords.transform_to(ref_frame)

        # Compute the pixel positions in the *reference* grid of the pixels
        # from the *input* image
        jj, ii = ref_wcs.world_to_pixel(frame_coords)

        # Finally, make the coordinates relative to the reference point
        ref_j, ref_i = get_x_y_values(ref_wcs, ref_ra, ref_dec)
        jj -= ref_j
        ii -= ref_i

        delta_plane = delta[0] * jj + delta[1] * ii + delta[2]
        im.data -= delta_plane
    else:
        raise ValueError(f"fit_type {fit_type} not known, should be one of {ALLOWED_FIT_TYPES}")

    # Keep originally-empty pixels empty
    im.data[zero_idx] = 0

    return im
def write_lvlmatch_meta(im, delta, fit_type):
    """Write level match coefficients to the ASDF metadata.

    Must be called before im.save().

    Args:
        im: Open JWST datamodel
        delta: Coefficients that were subtracted: [offset] for 'level', or
            [x_slope, y_slope, offset] for 'level+slope'
        fit_type: Type of fitting. 'level' or 'level+slope'

    Raises:
        ValueError: If fit_type is not one of ALLOWED_FIT_TYPES, rather than
            silently writing no metadata
    """
    if fit_type == "level":
        level_match = {
            "fit_type": fit_type,
            "offset": float(delta[0]),
        }
    elif fit_type == "level+slope":
        level_match = {
            "fit_type": fit_type,
            "x_slope": float(delta[0]),
            "y_slope": float(delta[1]),
            "offset": float(delta[2]),
        }
    else:
        raise ValueError(f"fit_type should be one of {ALLOWED_FIT_TYPES}")

    im.meta.level_match = level_match
def plane(x, y, params):
    """Evaluate the plane ``params[0] * x + params[1] * y + params[2]``

    Args:
        x: x coordinates (scalar or array)
        y: y coordinates (scalar or array)
        params: Plane parameters; only the first three entries are used

    Returns:
        The plane evaluated at (x, y)
    """
    x_slope, y_slope, offset = params[0], params[1], params[2]
    return x_slope * x + y_slope * y + offset
def plane_resid(params,
                points,
                err=None,
                rescale_result=False,
                return_sum=True,
                ):
    """Compute the (chi-square style) residual between a plane and points

    Args:
        params: Parameters for the plane, see ``plane``
        points: Array of (x, y, z) rows
        err: Error on each z-point. If given, squared residuals are divided
            by err**2. Defaults to None
        rescale_result: If True, divide the result by sqrt(2N), N being the
            number of points, so the chi-square space isn't so severe.
            Defaults to False
        return_sum: Whether to return the NaN-ignoring sum, or all the
            individual values. Defaults to True

    Returns:
        Scalar sum of the (scaled) squared residuals, or the per-point array
    """
    model_z = plane(points[:, 0], points[:, 1], params)
    residual = points[:, 2] - model_z

    chisq = residual ** 2 if err is None else residual ** 2 / err ** 2

    if rescale_result:
        # Scale chisq using sqrt(2N) to account for large number of points
        chisq /= np.sqrt(2 * len(points[:, 0]))

    return np.nansum(chisq) if return_sum else chisq
def cross(a, b):
    """Return the 3D cross product ``a x b`` as a list

    Args:
        a: First 3-vector (indexable)
        b: Second 3-vector (indexable)
    """
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    return [
        ay * bz - az * by,
        az * bx - ax * bz,
        ax * by - ay * bx,
    ]
def print_delta(file,
                delta,
                fit_type="level",
                ):
    """Log the subtraction coefficients for a file in a readable form

    Args:
        file: Filename the delta applies to (only used in the log message)
        delta: Delta coefficients; [offset] for 'level', or
            [x_slope, y_slope, offset] for 'level+slope'
        fit_type: Which type of fit was done. See ALLOWED_FIT_TYPES. Defaults to "level"

    Raises:
        ValueError: If fit_type is not one of ALLOWED_FIT_TYPES
    """
    if fit_type == "level":
        log.info(f"{file}, delta={delta[0]:.2e}")
        return

    if fit_type != "level+slope":
        raise ValueError(f"fit_type should be one of {ALLOWED_FIT_TYPES}")

    # Write signs explicitly so the plane reads "... - 1e+00" rather than "+ -1e+00"
    y_sign = "-" if delta[1] < 0 else "+"
    offset_sign = "-" if delta[2] < 0 else "+"
    message = (
        f"{file}, plane={delta[0]:.2e} * x (pix) "
        f"{y_sign} {np.abs(delta[1]):.2e} * y (pix) "
        f"{offset_sign} {np.abs(delta[2]):.2e}"
    )
    log.info(message)
def get_ra_dec_values(wcs,
                      jj,
                      ii,
                      units=u.arcsec,
                      return_coords=False,
                      ):
    """Convert pixel positions to RA/Dec using a WCS

    Args:
        wcs: Input WCS
        jj: j (x) pixel coords
        ii: i (y) pixel coords
        units: Unit for the returned RA/Dec values. Defaults to u.arcsec
        return_coords: If True, return the sky coordinates object from
            ``wcs.pixel_to_world`` instead of converting

    Returns:
        Either the sky coordinates, or an (ra, dec) tuple of plain values in ``units``
    """
    sky = wcs.pixel_to_world(jj, ii)
    if not return_coords:
        return sky.ra.to(units).value, sky.dec.to(units).value
    return sky
def get_x_y_values(wcs,
                   ras,
                   decs,
                   units=u.arcsec,
                   ):
    """Convert RA/Dec values to pixel positions using a WCS

    Args:
        wcs: Input WCS
        ras: RAs, as plain values in ``units``
        decs: Decs, as plain values in ``units``
        units: Unit of the input RA/Dec values. Defaults to u.arcsec

    Returns:
        (x, y) pixel positions
    """
    sky = SkyCoord(ras * units, decs * units)
    pixel_x, pixel_y = wcs.world_to_pixel(sky)
    return pixel_x, pixel_y
[docs]
class LevelMatchStep:
def __init__(
        self,
        in_dir,
        out_dir,
        step_ext,
        procs,
        band,
        fit_type_dithers="level",
        fit_type_recombine_lyot="level",
        fit_type_combine_nircam_short="level",
        fit_type_mosaic_tiles="level",
        recombine_lyot=False,
        combine_nircam_short=False,
        do_local_subtraction=True,
        sigma=3,
        npixels=3,
        dilate_size=7,
        max_iters=20,
        max_points=10000,
        do_sigma_clip=False,
        weight_method="equal",
        min_area_percent=0.002,
        min_linear_frac=0.2,
        rms_sig_limit=2,
        reproject_func="interp",
        overwrite=False,
):
    """Perform background matching between tiles

    This step performs background matching by minimizing the
    per-pixel differences between overlapping tiles. It does
    this first for dither groups, before creating a stacked
    image of these (to maximize areal coverage) and minimizing
    between all stacked images within a mosaic. This is necessary
    for observations that don't really have a background, and
    performs significantly better than the JWST implementation.

    N.B. If you use this, skymatch in the level 3 pipeline stage
    should be global or off, to avoid undoing this work

    Args:
        in_dir: Input directory
        out_dir: Output directory
        step_ext: .fits extension for the files going
            into the step
        procs: Number of parallel processes to run
        band: JWST band
        fit_type_dithers: What kind of fit to do to match levels between
            dithers in a single mosaic tile. Options are 'level' (the default,
            just fits a single offset), and 'level+slope' (which will fit a plane)
        fit_type_recombine_lyot: What kind of fit to do to match levels between
            science and lyot in a single mosaic tile. Options are 'level'
            (the default, just fits a single offset), and 'level+slope' (which will
            fit a plane)
        fit_type_combine_nircam_short: What kind of fit to do to match levels between
            the four NIRCam short chips. Options are 'level' (the default, just fits
            a single offset), and 'level+slope' (which will fit a plane)
        fit_type_mosaic_tiles: What kind of fit to do to match levels between
            mosaic tiles. Options are 'level' (the default, just fits a single offset),
            and 'level+slope' (which will fit a plane)
        recombine_lyot: If True, will recombine the lyot coronagraph
            into the main chip after the initial round of level matching.
            This will force the main science chip to have a 0 correction,
            as the lyot seems to be more wobbly. Defaults to False
        combine_nircam_short: Whether to combine the four NIRCam short
            chips before matching in a mosaic. Defaults to False
        do_local_subtraction: Whether to do a sigma-clipped local median
            subtraction. Only supported with fit_type_dithers='level';
            switched off (with a warning) otherwise. Defaults to True
        sigma: Sigma for sigma-clipping. Defaults to 3
        npixels: Pixels to grow for masking. Defaults to 3
        dilate_size: make_source_mask dilation size. Defaults to 7
        max_iters: Maximum sigma-clipping iterations. Defaults to 20
        max_points: Maximum points to include in histogram plots. This step can
            be slow so this speeds it up. Defaults to 10000
        do_sigma_clip: Whether to do sigma-clipping on data when reprojecting.
            Defaults to False
        weight_method: How to weight in least-squares minimization. Options are
            'equal' (equal weighting), 'npix' (weight by number of pixels), and
            'rms' (weight by rms of the delta values). Defaults to 'equal'
        min_area_percent: Minimum percentage of average areal overlap to remove tiles.
            Defaults to 0.002 (0.2%)
        min_linear_frac: Minimum linear overlap in any direction to keep tiles.
            Defaults to 0.2
        rms_sig_limit: Sigma limit for cutting off noisy fits. Defaults to 2
        reproject_func: Which reproject function to use. Defaults to 'interp',
            but can also be 'exact' or 'adaptive'
        overwrite: Whether to overwrite or not. Defaults
            to False

    Raises:
        ValueError: If reproject_func or any fit_type_* argument is not one
            of the allowed options
    """
    if reproject_func not in ALLOWED_REPROJECT_FUNCS:
        raise ValueError(f"reproject_func should be one of {ALLOWED_REPROJECT_FUNCS}")

    # Validate every fit type, in a fixed order so the first bad one is reported
    fit_type_args = {
        "fit_type_dithers": fit_type_dithers,
        "fit_type_recombine_lyot": fit_type_recombine_lyot,
        "fit_type_combine_nircam_short": fit_type_combine_nircam_short,
        "fit_type_mosaic_tiles": fit_type_mosaic_tiles,
    }
    for arg_name, arg_value in fit_type_args.items():
        if arg_value not in ALLOWED_FIT_TYPES:
            raise ValueError(f"{arg_name} should be one of {ALLOWED_FIT_TYPES}")

    if do_local_subtraction and fit_type_dithers not in ["level"]:
        log.warning("Cannot do local subtraction for methods beyond simple offset. Switching off")
        do_local_subtraction = False

    self.in_dir = in_dir
    self.out_dir = out_dir
    self.step_ext = step_ext
    self.procs = procs
    self.band = band
    self.fit_type_dithers = fit_type_dithers
    self.fit_type_recombine_lyot = fit_type_recombine_lyot
    self.fit_type_combine_nircam_short = fit_type_combine_nircam_short
    self.fit_type_mosaic_tiles = fit_type_mosaic_tiles
    self.recombine_lyot = recombine_lyot
    self.combine_nircam_short = combine_nircam_short
    self.do_local_subtraction = do_local_subtraction
    self.sigma = sigma
    self.npixels = npixels
    self.dilate_size = dilate_size
    self.max_iters = max_iters
    self.max_points = max_points
    self.do_sigma_clip = do_sigma_clip
    self.weight_method = weight_method
    self.min_area_percent = min_area_percent
    self.min_linear_frac = min_linear_frac
    self.rms_sig_limit = rms_sig_limit
    self.reproject_func = reproject_func
    self.overwrite = overwrite

    # Plots live inside the output directory; creating them also creates out_dir
    self.plot_dir = os.path.join(
        self.out_dir,
        "plots",
    )
    os.makedirs(self.plot_dir, exist_ok=True)
[docs]
def do_step(self):
    """Run level matching

    Stages, in order:
        1. Match levels between the dithers in each mosaic tile
           (reading from in_dir, writing to out_dir)
        2. Optionally (MIRI only) match and recombine the lyot with the main chip
        3. Optionally (NIRCam short only) match the four short chips
        4. If there is more than one dither group, match between mosaic tiles

    A sentinel file in out_dir marks the step as complete, so a rerun is
    skipped unless overwrite is set.

    Returns:
        True if the step succeeded (or had already been run), False if any
        stage reported failures
    """
    if self.overwrite and os.path.exists(self.out_dir):
        shutil.rmtree(self.out_dir)

    # (Re)create the output and plot directories. plot_dir lives inside
    # out_dir, so this creates both
    os.makedirs(self.plot_dir, exist_ok=True)

    # Check if we've already run the step
    step_complete_file = os.path.join(
        self.out_dir,
        "level_match_step_complete.txt",
    )
    if os.path.exists(step_complete_file):
        log.info("Step already run")
        return True

    def _get_step_files(directory):
        """Sorted list of this step's files in a directory"""
        return sorted(
            glob.glob(
                os.path.join(
                    directory,
                    f"*_{self.step_ext}.fits",
                )
            )
        )

    dithers = get_dithers(_get_step_files(self.in_dir))

    success = self.match_dithers(dithers,
                                 fit_type=self.fit_type_dithers,
                                 )
    if not success:
        log.warning("Failures detected level matching between individual dithers")
        return False

    if self.recombine_lyot:
        # First, check we're in MIRI mode
        band_type = get_band_type(self.band)
        if band_type in ["miri"]:
            success = self.match_lyot_science(dithers,
                                              fit_type=self.fit_type_recombine_lyot,
                                              )
            if not success:
                log.warning("Failures detected level matching between lyot and main science")
                return False

            # Redo the dithers, since now the "l" and "s" will have potentially been dropped
            dithers = get_dithers(_get_step_files(self.out_dir))

    if self.combine_nircam_short:
        # First, check we're in NIRCam short mode
        band_type, short_long_nircam = get_band_type(self.band,
                                                     short_long_nircam=True,
                                                     )
        if short_long_nircam in ["nircam_short"]:
            success = self.match_nircam_short(dithers,
                                              fit_type=self.fit_type_combine_nircam_short,
                                              )
            if not success:
                log.warning("Failures detected level matching between NIRCam short chips")
                return False

            # Redo the dithers, since now we'll combine the 4 imaging chips for the short NIRCam
            dithers = get_dithers(_get_step_files(self.out_dir),
                                  combine_nircam_short=True,
                                  )

    if len(dithers) > 1:
        success = self.match_mosaic_tiles(dithers,
                                          fit_type=self.fit_type_mosaic_tiles,
                                          )
        if not success:
            log.warning("Failures detected level matching between mosaic tiles")
            return False

    # Mark the step as complete with an empty sentinel file
    with open(step_complete_file, "w+"):
        pass

    return True
[docs]
def match_dithers(self,
                  dithers,
                  fit_type="level",
                  ):
    """Match levels between the dithers in each mosaic tile

    Reads files from in_dir and writes the corrected files (with level-match
    metadata attached) to out_dir under the same names.

    Args:
        dithers: List of dither groups
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"

    Returns:
        True on completion

    Raises:
        ValueError: If fit_type is not one of ALLOWED_FIT_TYPES
    """
    if fit_type not in ALLOWED_FIT_TYPES:
        raise ValueError(f"fit_type should be one of {ALLOWED_FIT_TYPES}")

    # Nothing to do, and a zero-process pool would raise a ValueError
    if len(dithers) == 0:
        log.warning("No dithers found to level match")
        return True

    # Ensure we're not wasting processes
    procs = np.nanmin([self.procs, len(dithers)])

    deltas, dither_files, ref_ras, ref_decs, optimal_wcses, optimal_shapes = self.get_per_dither_delta(
        dithers=dithers,
        fit_type=fit_type,
        procs=procs,
    )

    # Apply this calculated value. The per-dither lists are built in lockstep,
    # so zip keeps each dither's results together
    for deltas_idx, dither_files_idx, ref_ra, ref_dec, optimal_wcs, optimal_shape in zip(
            deltas, dither_files, ref_ras, ref_decs, optimal_wcses, optimal_shapes,
    ):

        # If we're including a local subtraction, do it here, as a
        # sigma-clipped median of the first file in the dither
        if self.do_local_subtraction:
            with datamodels.open(dither_files_idx[0]) as im:
                data = copy.deepcopy(im.data)

                # Mask out bad data
                dq_bit_mask = get_dq_bit_mask(im.dq)
                data[dq_bit_mask != 0] = np.nan
                data[data == 0] = np.nan

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    mask = make_source_mask(
                        data,
                        nsigma=self.sigma,
                        dilate_size=self.dilate_size,
                        sigclip_iters=self.max_iters,
                    )

                    # Calculate sigma-clipped median
                    local_delta = sigma_clipped_stats(
                        data,
                        mask=mask,
                        sigma=self.sigma,
                        maxiters=self.max_iters,
                    )[1]
        else:
            local_delta = 0

        for i, dither_file in enumerate(dither_files_idx):
            short_file = os.path.basename(dither_file)
            out_file = os.path.join(
                self.out_dir,
                short_file,
            )

            # Copy, since the local offset is folded into the constant term
            delta = copy.deepcopy(deltas_idx[i])
            delta[-1] += local_delta

            print_delta(file=short_file,
                        delta=delta,
                        fit_type=fit_type,
                        )

            with datamodels.open(dither_file) as im:
                im = apply_subtraction(im,
                                       delta,
                                       fit_type=fit_type,
                                       ref_ra=ref_ra,
                                       ref_dec=ref_dec,
                                       ref_wcs=optimal_wcs,
                                       ref_shape=optimal_shape,
                                       )
                write_lvlmatch_meta(im, delta, fit_type)
                im.save(out_file)

    return True
[docs]
def match_lyot_science(self,
                       dithers,
                       fit_type="level",
                       ):
    """Match levels between each individual lyot/main science chip, and recombine

    Stacked images are built in a temporary "<out_dir>_stacked" directory.
    It is recreated empty before use, so stale images from an interrupted
    run are never picked up, and it is always removed afterwards.

    Args:
        dithers: List of dithers
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"

    Returns:
        True if matching succeeded (or there was nothing to match), False
        if any failures were detected
    """
    log.info("Matching levels between lyot and main chip and recombining")

    # We now want to find the separate l and s files
    combined_miri_dithers = np.unique(
        [dither[:-1] for dither in dithers if dither[-1] in ["l", "s"]]
    )
    if len(combined_miri_dithers) == 0:
        log.info("No split MIRI lyot/main chip found. Returning")
        return True

    # From the individually corrected images get a stacked image
    stacked_dir = self.out_dir + "_stacked"
    if os.path.exists(stacked_dir):
        shutil.rmtree(stacked_dir)
    os.makedirs(stacked_dir)

    try:
        procs = np.nanmin([self.procs, len(dithers)])
        successes = self.make_stacked_images(
            dithers=dithers,
            stacked_dir=stacked_dir,
            procs=procs,
        )
        if not np.all(successes):
            log.warning("Failures detected making stacked images")
            return False

        procs = np.nanmin([self.procs, len(combined_miri_dithers)])
        successes = []
        with mp.get_context("fork").Pool(procs) as pool:
            for success in tqdm(
                    pool.imap_unordered(
                        partial(
                            self.parallel_match_lyot_science,
                            stacked_dir=stacked_dir,
                            fit_type=fit_type,
                        ),
                        combined_miri_dithers,
                    ),
                    total=len(combined_miri_dithers),
                    desc="Matching lyot and main science",
                    ascii=True,
                    disable=True,
            ):
                successes.append(success)

            pool.close()
            pool.join()
        gc.collect()
    finally:
        # Remove the stacked images, whether or not matching succeeded
        shutil.rmtree(stacked_dir, ignore_errors=True)

    if not np.all(successes):
        log.warning("Failures detected matching lyot and main science")
        return False

    return True
[docs]
def match_nircam_short(self,
                       dithers,
                       fit_type="level",
                       ):
    """Match levels between the NIRCam short chips

    Stacked images are built in a temporary "<out_dir>_stacked" directory.
    It is recreated empty before use, so stale images from an interrupted
    run are never picked up, and it is always removed afterwards.

    Args:
        dithers: List of dithers
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"

    Returns:
        True if matching succeeded (or there was nothing to match), False
        if any failures were detected
    """
    log.info("Matching levels between NIRCam shorts")

    # We now want to group up the four NIRCam imaging chips, by dropping
    # the final (chip number) character of each dither
    combined_nircam_dithers = np.unique([dither[:-1] for dither in dithers])
    if len(combined_nircam_dithers) == 0:
        log.info("No NIRCam shorts found. Returning")
        return True

    # From the individually corrected images get a stacked image
    stacked_dir = self.out_dir + "_stacked"
    if os.path.exists(stacked_dir):
        shutil.rmtree(stacked_dir)
    os.makedirs(stacked_dir)

    try:
        procs = np.nanmin([self.procs, len(dithers)])
        successes = self.make_stacked_images(
            dithers=dithers,
            stacked_dir=stacked_dir,
            procs=procs,
        )
        if not np.all(successes):
            log.warning("Failures detected making stacked images")
            return False

        procs = np.nanmin([self.procs, len(combined_nircam_dithers)])
        successes = []
        with mp.get_context("fork").Pool(procs) as pool:
            for success in tqdm(
                    pool.imap_unordered(
                        partial(
                            self.parallel_match_nircam_short,
                            stacked_dir=stacked_dir,
                            fit_type=fit_type,
                        ),
                        combined_nircam_dithers,
                    ),
                    total=len(combined_nircam_dithers),
                    desc="Matching NIRCAM short chips",
                    ascii=True,
                    disable=True,
            ):
                successes.append(success)

            pool.close()
            pool.join()
        gc.collect()
    finally:
        # Remove the stacked images, whether or not matching succeeded
        shutil.rmtree(stacked_dir, ignore_errors=True)

    if not np.all(successes):
        log.warning("Failures detected matching NIRCam short chips")
        return False

    return True
[docs]
def match_mosaic_tiles(self,
                       dithers,
                       fit_type="level",
                       ):
    """Match levels between each mosaic tile

    Builds one stacked image per dither group in a temporary
    "<out_dir>_stacked" directory, matches levels between those, and applies
    the resulting corrections in place to the files in out_dir. The stacked
    directory is recreated empty first, because every stacked file in it is
    globbed for matching, and it is always removed afterwards.

    Args:
        dithers: List of dither groups
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"

    Returns:
        True on success, False if making the stacked images failed
    """
    log.info("Matching levels between mosaic tiles")

    # From the individually corrected images get a stacked image
    stacked_dir = self.out_dir + "_stacked"
    if os.path.exists(stacked_dir):
        shutil.rmtree(stacked_dir)
    os.makedirs(stacked_dir)

    try:
        procs = np.nanmin([self.procs, len(dithers)])
        successes = self.make_stacked_images(
            dithers=dithers,
            stacked_dir=stacked_dir,
            procs=procs,
        )
        if not np.all(successes):
            log.warning("Failures detected making stacked images")
            return False

        # Now match up these stacked images
        stacked_files = sorted(
            glob.glob(
                os.path.join(stacked_dir, f"*_{self.step_ext}.fits")
            )
        )

        # From these files, select a reference image based on the closest background,
        # if possible
        bgr_times = np.full(len(stacked_files), np.nan)
        for i, f in enumerate(stacked_files):
            with fits.open(f) as hdu:
                if "DT_BGR" in hdu[0].header:
                    bgr_times[i] = float(hdu[0].header["DT_BGR"])

        if np.all(np.isnan(bgr_times)):
            # If we don't have any background times, use the first image
            ref_idx = 0
            log.info("Will use first image as the reference image")
        else:
            # Get the one closest in time to the backgrounds
            ref_idx = np.nanargmin(np.abs(bgr_times))

            # If it's selected a lyot image, force it to be the science, as long
            # as the matching science image actually exists
            full_file = stacked_files[ref_idx]
            tidy_file = os.path.basename(full_file)
            if "mirimagel" in tidy_file:
                sci_tidy_file = tidy_file.replace("mirimagel", "mirimages")
                sci_full_file = os.path.join(os.path.dirname(full_file), sci_tidy_file)
                if sci_full_file in stacked_files:
                    ref_idx = stacked_files.index(sci_full_file)
                    tidy_file = sci_tidy_file
                else:
                    log.warning(f"No science image matching {tidy_file}, keeping it as the reference")

            log.info(f"Will use {tidy_file} as the reference image")

        (
            delta_matrix,
            npix_matrix,
            rms_matrix,
            lin_size_matrix,
            valid_matrix,
            optimal_wcs,
            optimal_shape,
            ref_ra,
            ref_dec,
        ) = self.calculate_delta(
            stacked_files,
            fit_type=fit_type,
            stacked_image=True,
            procs=procs,
        )

        deltas = self.find_optimum_deltas(
            delta_mat=delta_matrix,
            npix_mat=npix_matrix,
            rms_mat=rms_matrix,
            lin_size_mat=lin_size_matrix,
            valid_mat=valid_matrix,
            fit_type=fit_type,
            ref_idx=ref_idx,
        )

        # Subtract the per-dither delta, do in place. The stacked file name
        # contains a literal "*", so it doubles as a glob over the dither's files
        for idx, delta in enumerate(deltas):
            short_stack_file = os.path.basename(stacked_files[idx])
            dither_files = sorted(
                glob.glob(
                    os.path.join(
                        self.out_dir,
                        short_stack_file,
                    )
                )
            )

            print_delta(file=short_stack_file,
                        delta=delta,
                        fit_type=fit_type,
                        )

            for dither_file in dither_files:
                with datamodels.open(dither_file) as im:
                    im = apply_subtraction(im,
                                           delta,
                                           fit_type=fit_type,
                                           ref_ra=ref_ra,
                                           ref_dec=ref_dec,
                                           ref_wcs=optimal_wcs,
                                           ref_shape=optimal_shape,
                                           )
                    write_lvlmatch_meta(im, delta, fit_type)
                    im.save(dither_file)
    finally:
        # Remove the stacked images, even if something above failed
        shutil.rmtree(stacked_dir, ignore_errors=True)

    return True
[docs]
def get_per_dither_delta(
        self,
        dithers,
        fit_type="level",
        procs=1,
):
    """Compute the per-file deltas for every dither group, in parallel

    Each dither is handled by ``parallel_per_dither_delta`` in a worker
    process. Results arrive in completion order, so the returned lists are
    aligned with each other but not necessarily with ``dithers``.

    Args:
        dithers: List of dithers to get deltas for
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"
        procs: Number of processes to run simultaneously. Defaults
            to 1

    Returns:
        Tuple of lists (deltas, dither_files, ref_ras, ref_decs,
        optimal_wcses, optimal_shapes), one entry per dither
    """
    log.info("Getting deltas for individual dithers")

    per_dither = []
    worker = partial(
        self.parallel_per_dither_delta,
        fit_type=fit_type,
    )

    with mp.get_context("fork").Pool(procs) as pool:
        for delta, dither_file, ref_ra, ref_dec, optimal_wcs, optimal_shape in tqdm(
                pool.imap_unordered(worker, dithers),
                total=len(dithers),
                desc="Matching individual dithers",
                ascii=True,
        ):
            per_dither.append((delta, dither_file, ref_ra, ref_dec, optimal_wcs, optimal_shape))

        pool.close()
        pool.join()
    gc.collect()

    deltas = [entry[0] for entry in per_dither]
    dither_files = [entry[1] for entry in per_dither]
    ref_ras = [entry[2] for entry in per_dither]
    ref_decs = [entry[3] for entry in per_dither]
    optimal_wcses = [entry[4] for entry in per_dither]
    optimal_shapes = [entry[5] for entry in per_dither]

    return deltas, dither_files, ref_ras, ref_decs, optimal_wcses, optimal_shapes
[docs]
def parallel_per_dither_delta(
        self,
        dither,
        fit_type="level",
):
    """Compute the optimum deltas for the files in a single dither group

    Intended to run in a worker process. BLAS/OpenMP thread pools are
    limited to one thread while the delta matrix is calculated.

    Args:
        dither: Input dither (glob prefix for files in in_dir)
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"

    Returns:
        Tuple of (deltas, dither_files, ref_ra, ref_dec, optimal_wcs, optimal_shape)
    """
    pattern = os.path.join(
        self.in_dir,
        f"{dither}*_{self.step_ext}.fits",
    )
    dither_files = sorted(glob.glob(pattern))

    with threadpool_limits(limits=1, user_api=None):
        (
            delta_matrix,
            npix_matrix,
            rms_matrix,
            lin_size_matrix,
            valid_matrix,
            optimal_wcs,
            optimal_shape,
            ref_ra,
            ref_dec,
        ) = self.calculate_delta(dither_files,
                                 fit_type=fit_type,
                                 )

    if len(dither_files) <= 1:
        # Weird edge case where we only have one dither: nothing to match, so zero deltas
        deltas = np.zeros([delta_matrix.shape[0], delta_matrix.shape[-1]])
    else:
        deltas = self.find_optimum_deltas(
            delta_mat=delta_matrix,
            npix_mat=npix_matrix,
            rms_mat=rms_matrix,
            lin_size_mat=lin_size_matrix,
            valid_mat=valid_matrix,
            fit_type=fit_type,
        )

    return deltas, dither_files, ref_ra, ref_dec, optimal_wcs, optimal_shape
[docs]
def make_stacked_images(
        self,
        dithers,
        stacked_dir,
        procs=1,
):
    """Build one stacked image per dither, in parallel

    Args:
        dithers: List of dithers to stack
        stacked_dir: Where to save stacked images to
        procs: Number of simultaneous processes to run.
            Defaults to 1

    Returns:
        List of per-dither success values, in completion order
    """
    log.info("Creating stacked images")

    stack_one = partial(
        self.parallel_make_stacked_image,
        out_dir=stacked_dir,
    )

    with mp.get_context("fork").Pool(procs) as pool:
        successes = list(
            tqdm(
                pool.imap_unordered(stack_one, dithers),
                total=len(dithers),
                desc="Creating stacked images",
                ascii=True,
            )
        )
        pool.close()
        pool.join()
    gc.collect()

    return successes
[docs]
def parallel_make_stacked_image(
        self,
        dither,
        out_dir,
):
    """Light wrapper around parallelising the stacked image

    The output name is the first input's name with its third
    underscore-separated field replaced by a literal ``*``. Later steps use
    that name as a glob pattern over the files in out_dir.

    Args:
        dither: Dither to stack (glob prefix for files in self.out_dir)
        out_dir: Directory to save to

    Returns:
        True if the stacked image was made, False otherwise (including when
        no files match the dither)
    """
    files = sorted(
        glob.glob(
            os.path.join(
                self.out_dir,
                f"{dither}*_{self.step_ext}.fits",
            )
        )
    )

    # Report a failure rather than crashing the worker with an IndexError
    if not files:
        log.warning(f"No files found for dither {dither}, cannot create stacked image")
        return False

    # Create output name
    file_name_split = os.path.basename(files[0]).split("_")
    file_name_split[2] = "*"
    out_name = os.path.join(out_dir, "_".join(file_name_split))

    # Make the stacked image. Set auto-rotate True to minimize the image shape
    success = make_stacked_image(
        files=files,
        out_name=out_name,
        additional_hdus="ERR",
        reproject_func=self.reproject_func,
        auto_rotate=True,
    )

    return bool(success)
[docs]
def parallel_match_lyot_science(self,
                                dither,
                                stacked_dir,
                                fit_type
                                ):
    """Function to parallelise up combining the lyot back into the main science chip

    Because the lyot seems to behave a little weirdly in its backgrounds from time-to-time,
    we force the main science correction to be 0 and put that all into the lyot.
    The corrected lyot region is copied into the science image, saved under the
    name without the "s"/"l" suffix, and the two split files are removed.

    Args:
        dither: Dither to level match and combine
        stacked_dir: Directory containing stacked images
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES

    Returns:
        True on success, False if the expected lyot + science stacked pair
        was not found
    """
    stacked_files = sorted(
        glob.glob(
            os.path.join(stacked_dir, f"{dither}*_{self.step_ext}.fits")
        )
    )

    # Sorted, we expect exactly the lyot ("mirimagel") then the science ("mirimages")
    stacked_names = [os.path.basename(f) for f in stacked_files]
    if (len(stacked_names) != 2
            or "mirimagel" not in stacked_names[0]
            or "mirimages" not in stacked_names[1]):
        log.warning(f"Expected one lyot and one science stacked image for {dither}, found {stacked_names}")
        return False

    (
        delta_matrix,
        npix_matrix,
        rms_matrix,
        lin_size_matrix,
        valid_matrix,
        optimal_wcs,
        optimal_shape,
        ref_ra,
        ref_dec,
    ) = self.calculate_delta(
        stacked_files,
        fit_type=fit_type,
        stacked_image=True,
        procs=None,
    )

    lyot_delta, _sci_delta = self.find_optimum_deltas(
        delta_mat=delta_matrix,
        npix_mat=npix_matrix,
        rms_mat=rms_matrix,
        lin_size_mat=lin_size_matrix,
        valid_mat=valid_matrix,
        fit_type=fit_type,
        ref_idx=1,  # Since we want to base the correction on the science files
    )

    # The science stacked name contains a literal "*", so it globs all science files
    short_stack_file = stacked_names[1]
    sci_files = sorted(
        glob.glob(
            os.path.join(
                self.out_dir,
                short_stack_file,
            )
        )
    )

    short_file = short_stack_file.replace("mirimages", "mirimage")
    print_delta(short_file,
                lyot_delta,
                fit_type=fit_type,
                )

    for sci_file in sci_files:
        # Get the equivalent lyot file, and the out file (dropping that s/l).
        # Only rewrite the file name, never the directory part
        sci_dir, sci_name = os.path.split(sci_file)
        lyot_file = os.path.join(sci_dir, sci_name.replace("mirimages", "mirimagel"))
        out_file = os.path.join(sci_dir, sci_name.replace("mirimages", "mirimage"))

        with datamodels.open(sci_file) as sci_im, datamodels.open(lyot_file) as lyot_im:
            # Do the subtraction, only needed on the lyot since we force the science correction
            # to 0
            lyot_im = apply_subtraction(lyot_im,
                                        lyot_delta,
                                        fit_type=fit_type,
                                        ref_ra=ref_ra,
                                        ref_dec=ref_dec,
                                        ref_wcs=optimal_wcs,
                                        ref_shape=optimal_shape,
                                        )

            # Force the lyot back into the science
            sci_im.data[LYOT_I, LYOT_J] = lyot_im.data[LYOT_I, LYOT_J]
            sci_im.dq[LYOT_I, LYOT_J] = lyot_im.dq[LYOT_I, LYOT_J]

            # Save
            write_lvlmatch_meta(sci_im, lyot_delta, fit_type)
            sci_im.save(out_file)

        # And finally, remove the two separate files to clean things up. os.remove
        # avoids passing paths through a shell, which broke on spaces/special characters
        os.remove(sci_file)
        os.remove(lyot_file)

    return True
[docs]
def parallel_match_nircam_short(self,
                                dither,
                                stacked_dir,
                                fit_type
                                ):
    """Match levels between the short NIRCam chips of one dither, in place

    Args:
        dither: Dither to level match (glob prefix for stacked images)
        stacked_dir: Directory containing stacked images
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES

    Returns:
        True on completion
    """
    stacked_pattern = os.path.join(stacked_dir, f"{dither}*_{self.step_ext}.fits")
    stacked_files = sorted(glob.glob(stacked_pattern))

    matrices = self.calculate_delta(
        stacked_files,
        fit_type=fit_type,
        stacked_image=True,
        procs=None,
    )
    (
        delta_matrix,
        npix_matrix,
        rms_matrix,
        lin_size_matrix,
        valid_matrix,
        optimal_wcs,
        optimal_shape,
        ref_ra,
        ref_dec,
    ) = matrices

    deltas = self.find_optimum_deltas(
        delta_mat=delta_matrix,
        npix_mat=npix_matrix,
        rms_mat=rms_matrix,
        lin_size_mat=lin_size_matrix,
        valid_mat=valid_matrix,
        fit_type=fit_type,
    )

    # Subtract the per-chip delta in place. Each stacked name contains a literal
    # "*", so it doubles as a glob over the matching files in out_dir
    for idx, delta in enumerate(deltas):
        stack_name = os.path.basename(stacked_files[idx])
        chip_files = sorted(glob.glob(os.path.join(self.out_dir, stack_name)))

        print_delta(file=stack_name,
                    delta=delta,
                    fit_type=fit_type,
                    )

        for chip_file in chip_files:
            with datamodels.open(chip_file) as model:
                model = apply_subtraction(model,
                                          delta,
                                          fit_type=fit_type,
                                          ref_ra=ref_ra,
                                          ref_dec=ref_dec,
                                          ref_wcs=optimal_wcs,
                                          ref_shape=optimal_shape,
                                          )
                write_lvlmatch_meta(model, delta, fit_type)
                model.save(chip_file)

    return True
[docs]
def calculate_delta(
        self,
        files,
        fit_type="level",
        stacked_image=False,
        procs=None,
):
    """Match relative offsets between tiles

    All files are reprojected onto a common optimal WCS, then every unique
    pair (i, j) with i < j is level-matched. Deltas are stored
    antisymmetrically (deltas[j, i] = delta, deltas[i, j] = -delta), the
    other per-pair values symmetrically.

    Args:
        files (list): List of files to match. Entries may themselves be
            lists of files
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"
        stacked_image: Whether this is a stacked image or not.
            Default to False
        procs (int, optional): Number of processes to use for reprojection.
            Defaults to None, which runs everything in series

    Returns:
        Tuple of (deltas, weights, rmses, lin_sizes, valid_mat, optimal_wcs,
        optimal_shape, ref_ra, ref_dec). deltas has shape
        (n_files, n_files, n_coeff); the others matrices are
        (n_files, n_files), with weights holding each pair's n_pix. valid_mat
        is 0 on the diagonal and for pairs without a usable match.
        ref_ra/ref_dec are the centre of the optimal WCS, in arcsec

    Raises:
        ValueError: If fit_type is not one of ALLOWED_FIT_TYPES
    """
    if fit_type not in ("level", "level+slope"):
        raise ValueError(f"fit_type should be one of {ALLOWED_FIT_TYPES}")
    n_coeff = 1 if fit_type == "level" else 3

    n_files = len(files)
    deltas = np.zeros([n_files, n_files, n_coeff])
    weights = np.zeros([n_files, n_files])
    rmses = np.zeros_like(weights)
    lin_sizes = np.ones_like(weights)
    valid_mat = np.ones_like(weights)
    # A file is never matched against itself
    np.fill_diagonal(valid_mat, 0)

    def _store_match(i, j, n_pix, delta, rms, lin_size):
        """Record one pairwise result, flagging the pair invalid if it has no usable values"""
        if n_pix == 0 or delta is None or rms is None:
            valid_mat[i, j] = 0
            valid_mat[j, i] = 0
            return
        for n in range(n_coeff):
            deltas[j, i, n] = delta[n]
            deltas[i, j, n] = -delta[n]
        weights[j, i] = n_pix
        weights[i, j] = n_pix
        rmses[j, i] = rms
        rmses[i, j] = rms
        lin_sizes[j, i] = lin_size
        lin_sizes[i, j] = lin_size

    # Reproject all the HDUs. Start by building the optimal WCS over every input file
    if isinstance(files[0], list):
        files_flat = list(itertools.chain(*files))
    else:
        files_flat = copy.deepcopy(files)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        optimal_wcs, optimal_shape = find_optimal_celestial_wcs(
            files_flat,
            hdu_in="SCI",
            auto_rotate=True,
        )

    # The reference pixels are the centre of the shape, but we also want this as RA/Dec
    ref_y, ref_x = np.asarray(optimal_shape) // 2
    ref_ra, ref_dec = get_ra_dec_values(optimal_wcs, ref_x, ref_y)

    # Every unique (i, j) pair with i < j, in row-major order
    pairs = list(itertools.combinations(range(n_files), 2))

    if procs is None:
        # Serial: reproject files (maintaining structure), then match each pair
        file_reproj = [
            self.get_reproject(
                file=file,
                optimal_wcs=optimal_wcs,
                optimal_shape=optimal_shape,
                stacked_image=stacked_image,
            )
            for file in files
        ]

        for i, j in pairs:
            plot_name = self.get_plot_name(
                files[i],
                files[j],
            )
            n_pix, delta, rms, lin_size = self.get_level_match(
                files1=file_reproj[i],
                files2=file_reproj[j],
                fit_type=fit_type,
                ref_x=ref_x,
                ref_y=ref_y,
                plot_name=plot_name,
            )
            _store_match(i, j, n_pix, delta, rms, lin_size)
            gc.collect()
    else:
        # Reprojection is parallelised, since each file is independent
        n_procs = np.nanmin([procs, n_files])
        with mp.get_context("fork").Pool(n_procs) as pool:
            file_reproj = [None] * n_files
            for i, result in tqdm(
                    pool.imap_unordered(
                        partial(
                            self.parallel_get_reproject,
                            files=files,
                            optimal_wcs=optimal_wcs,
                            optimal_shape=optimal_shape,
                            stacked_image=stacked_image,
                        ),
                        range(n_files),
                    ),
                    total=n_files,
                    desc="Reprojecting files",
                    ascii=True,
            ):
                file_reproj[i] = result

            pool.close()
            pool.join()
        gc.collect()

        # The pairwise matches themselves are run one after another
        pair_results = [
            self.parallel_delta_matrix(
                ij=ij,
                files=files,
                fit_type=fit_type,
                file_reproj=file_reproj,
                ref_x=ref_x,
                ref_y=ref_y,
            )
            for ij in tqdm(pairs, ascii=True, desc="Calculating delta matrix")
        ]

        for (i, j), delta, n_pix, rms, lin_size in pair_results:
            _store_match(i, j, n_pix, delta, rms, lin_size)

        gc.collect()

    return deltas, weights, rmses, lin_sizes, valid_mat, optimal_wcs, optimal_shape, ref_ra, ref_dec
[docs]
def parallel_delta_matrix(
    self,
    ij,
    file_reproj,
    files,
    fit_type="level",
    ref_x=0,
    ref_y=0,
):
    """Compute the level match for a single (i, j) entry of the delta matrix

    Each pair is independent, so this can be mapped over all pairs. Thread
    pools used by numerical libraries are capped at 1 while the match runs.

    Args:
        ij: (i, j) index pair into ``files`` / ``file_reproj``
        file_reproj: Reprojected files, indexed the same way as ``files``
        files: Full list of files (used to build plot names)
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"
        ref_x: Reference x value to subtract to centre the fits. Defaults to 0
        ref_y: Reference y value to subtract to centre the fits. Defaults to 0

    Returns:
        Tuple of (ij, delta, n_pix, rms, lin_size), with the last four taken
        from get_level_match
    """
    first, second = ij[0], ij[1]

    # Only build a plot name when plotting has been switched on
    plot_name = (
        None
        if self.plot_dir is None
        else self.get_plot_name(files1=files[first], files2=files[second])
    )

    with threadpool_limits(limits=1, user_api=None):
        n_pix, delta, rms, lin_size = self.get_level_match(
            files1=file_reproj[first],
            files2=file_reproj[second],
            fit_type=fit_type,
            ref_x=ref_x,
            ref_y=ref_y,
            plot_name=plot_name,
        )

    gc.collect()
    return ij, delta, n_pix, rms, lin_size
[docs]
def get_level_match(
    self,
    files1,
    files2,
    fit_type,
    ref_x=0,
    ref_y=0,
    plot_name=None,
    maxiters=10,
    plane_fit_maxiters=20,
    plane_fit_abs_tol=1e-8,
    plane_fit_rel_tol=1e-5,
):
    """Calculate relative difference between groups of files on the same pixel grid

    For every overlapping (file1, file2) pair, the valid pixels of
    ``file2 - file1`` (with errors added in quadrature) are pooled, and then
    fit either with a sigma-clipped DC offset ('level') or an iteratively
    outlier-rejected plane ('level+slope').

    Plots are only written when ``plot_name`` is given:
    ``{plot_name}_ims.{png,pdf}`` (per-pair diff images), plus
    ``{plot_name}_hist.{pdf,png}`` for 'level' fits or
    ``{plot_name}_plane_fit.{pdf,png}`` for 'level+slope' fits.

    Args:
        files1: Reprojected file (dict with "data" and "err" entries, as made by
            get_reproject), or list of them, to get difference from
        files2: Reprojected file, or list of them, to get relative difference to
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES
        ref_x: Reference x-coord to perform fits relative to. None is treated
            as 0. Defaults to 0
        ref_y: Reference y-coord to perform fits relative to. None is treated
            as 0. Defaults to 0
        plot_name: Output plot name stem. If None, no plots are made.
            Defaults to None
        maxiters: Maximum iterations for the sigma-clipping. Defaults
            to 10
        plane_fit_maxiters: Maximum number of iterations for the plane fitting.
            Defaults to 20
        plane_fit_abs_tol: Absolute tolerance to define convergence in the
            plane fit. Defaults to 1e-8
        plane_fit_rel_tol: Relative tolerance to define convergence in
            the plane fitting. Defaults to 1e-5

    Returns:
        Tuple of (n_pix, delta, rms, lin_size):
            n_pix: number of valid overlapping pixels used
            delta: [offset] for 'level', or plane coefficients [a, b, c]
                (z = a * x + b * y + c) for 'level+slope'. None if n_pix is 0
            rms: sigma-clipped scatter of the diffs (or of the plane-fit
                residuals). None if n_pix is 0
            lin_size: 1 if any overlap passed the minimum linear extent test,
                else 0

    Raises:
        ValueError: If there are pixels to fit and fit_type is not one of
            ALLOWED_FIT_TYPES
    """
    diffs = []
    errs = []
    iis = []
    jjs = []

    if not isinstance(files1, list):
        files1 = [files1]
    if not isinstance(files2, list):
        files2 = [files2]

    n_pix = 0
    lin_size = 0

    # One image panel per (file1, file2) pair
    fig, axs = None, None
    if plot_name is not None:
        fig, axs = plt.subplots(
            nrows=len(files1),
            ncols=len(files2),
            figsize=(2.5 * len(files2), 2.5 * len(files1)),
            squeeze=False,
        )

    for file_idx1, file1 in enumerate(files1):

        if axs is not None:
            # Blank every panel in this row up-front, so pairs that get
            # skipped below don't show bare axes
            for ax in axs[file_idx1]:
                ax.set_axis_off()

        file1_data = copy.deepcopy(file1["data"])
        file1_err = copy.deepcopy(file1["err"])

        # Get out coordinates where data is valid, so we can do a linear
        # extent test
        file1_invalid = np.isnan(file1_data.array) | (file1_data.footprint == 0)
        ii, jj = np.indices(file1_data.array.shape, dtype=float)
        ii[file1_invalid] = np.nan
        jj[file1_invalid] = np.nan

        # If we have something that's all NaNs
        # (e.g. lyot on MIRI subarray obs.), skip
        if np.all(np.isnan(ii)):
            continue

        file1_iaxis = np.nanmax(ii) - np.nanmin(ii)
        file1_jaxis = np.nanmax(jj) - np.nanmin(jj)

        for file_idx2, file2 in enumerate(files2):
            file2_data = copy.deepcopy(file2["data"])
            file2_err = copy.deepcopy(file2["err"])

            if not file2_data.overlaps(file1_data):
                continue

            # Get diffs, remove anything outside the overlap footprint
            diff_subset = file2_data - file1_data
            diff_arr = diff_subset.array
            diff_foot = diff_subset.footprint
            diff_arr[diff_foot == 0] = np.nan

            # Errors added in quadrature, remove anything outside the footprint
            err_sq = file1_err * file1_err + file2_err * file2_err
            err_arr = np.sqrt(err_sq.array)
            err_arr[diff_foot == 0] = np.nan

            # Get out coordinates where file2 data is valid, so we can do a
            # linear extent test
            file2_invalid = np.isnan(file2_data.array) | (file2_data.footprint == 0)
            ii, jj = np.indices(file2_data.array.shape, dtype=float)
            ii[file2_invalid] = np.nan
            jj[file2_invalid] = np.nan

            # If we have something that's all NaNs
            # (e.g. lyot on MIRI subarray obs.), skip
            if np.all(np.isnan(ii)):
                continue

            file2_iaxis = np.nanmax(ii) - np.nanmin(ii)
            file2_jaxis = np.nanmax(jj) - np.nanmin(jj)

            # Same again for the overlap region itself
            diff_invalid = np.isnan(diff_arr)
            ii, jj = np.indices(diff_arr.shape, dtype=float)
            ii[diff_invalid] = np.nan
            jj[diff_invalid] = np.nan

            # If the slices are all NaNs, then we can just move on
            if np.all(np.isnan(ii)):
                continue

            diff_iaxis = np.nanmax(ii) - np.nanmin(ii)
            diff_jaxis = np.nanmax(jj) - np.nanmin(jj)

            # Include everything, but flag if we've hit the minimum linear extent
            if (
                diff_iaxis > file1_iaxis * self.min_linear_frac
                or diff_jaxis > file1_jaxis * self.min_linear_frac
                or diff_iaxis > file2_iaxis * self.min_linear_frac
                or diff_jaxis > file2_jaxis * self.min_linear_frac
            ):
                lin_size = 1

            # Get valid points, which are where we have finite diffs/errs and the error array isn't 0
            valid_idx = np.isfinite(diff_arr) & np.isfinite(err_arr) & (err_arr != 0)

            # Get the coords, account for the differences in where the arrays start
            diff_ii = ii[valid_idx] + diff_subset.bounds[1][0]
            diff_jj = jj[valid_idx] + diff_subset.bounds[0][0]

            pair_diffs = diff_arr[valid_idx].tolist()
            pair_errs = err_arr[valid_idx].tolist()

            n_pix += len(pair_diffs)
            diffs.extend(pair_diffs)
            errs.extend(pair_errs)
            iis.extend(diff_ii)
            jjs.extend(diff_jj)

            if axs is not None and len(pair_diffs) > 0:
                vmin, vmax = np.nanpercentile(pair_diffs, [1, 99])
                axs[file_idx1, file_idx2].imshow(
                    diff_arr,
                    origin="lower",
                    vmin=vmin,
                    vmax=vmax,
                    interpolation="nearest",
                )

    if fig is not None:
        if n_pix > 0:
            fig.savefig(f"{plot_name}_ims.png", bbox_inches="tight")
            fig.savefig(f"{plot_name}_ims.pdf", bbox_inches="tight")
        plt.close(fig)

    # Nothing overlapped, so there's nothing to fit
    if n_pix == 0:
        gc.collect()
        return n_pix, None, None, lin_size

    if fit_type == "level":

        # Just fit a DC offset. Sigma-clip to remove outliers in the distribution
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            _, delta, rms = sigma_clipped_stats(diffs, sigma=self.sigma, maxiters=maxiters)

        if plot_name is not None:
            # Subsample for the histogram if we have a lot of points
            if self.max_points is not None and len(diffs) > self.max_points:
                diffs_hist = random.sample(diffs, self.max_points)
            else:
                diffs_hist = diffs

            hist_range = np.nanpercentile(diffs_hist, [1, 99])

            plt.figure(figsize=(5, 4))
            plt.hist(
                diffs_hist,
                histtype="step",
                bins=50,
                range=hist_range,
                color="gray",
            )
            plt.axvline(
                delta,
                c="k",
                ls="--",
            )
            plt.xlabel("Diff (MJy/sr)")
            plt.ylabel("$N$")
            plt.grid()
            plt.tight_layout()
            plt.savefig(f"{plot_name}_hist.pdf", bbox_inches="tight")
            plt.savefig(f"{plot_name}_hist.png", bbox_inches="tight")
            plt.close()

        # For consistency with how we'll do things later, put this into a list
        delta = [delta]

    elif fit_type == "level+slope":

        diffs = np.array(diffs)
        errs = np.array(errs)

        # For this, we do a slope fit in the x, y, rather than RA/Dec plane
        # to avoid any potential spherical weirdness
        iis = np.array(iis)
        jjs = np.array(jjs)

        # Centre on the reference pixel, if we have one
        if ref_x is not None:
            jjs -= ref_x
        if ref_y is not None:
            iis -= ref_y

        # Also remove a mean x/y, to avoid fitting spurious correlations
        mean_j = np.nanmean(jjs)
        mean_i = np.nanmean(iis)
        jjs -= mean_j
        iis -= mean_i

        initial_offset = np.nanmedian(diffs)

        # Get an initial guess which we'll normalise things by. This is just a flat plane
        delta = np.array([0, 0, initial_offset], dtype=float)

        converged = False
        n_iter = 0
        while not converged and n_iter <= plane_fit_maxiters:

            prev_delta = copy.deepcopy(delta)

            # Look at the current plane we have, and reject points that are
            # significantly different to it
            delta_plane = plane(jjs, iis, delta)
            rms = np.nanstd(diffs - delta_plane)
            fit_idx = np.where(np.abs(diffs - delta_plane) < self.sigma * rms)

            # Do a fit to the residuals, using only the points we care about
            points = np.vstack(
                (
                    jjs[fit_idx],
                    iis[fit_idx],
                    diffs[fit_idx] - delta_plane[fit_idx],
                )
            ).T

            func = partial(plane_resid,
                           points=points,
                           err=errs[fit_idx],
                           rescale_result=True,
                           )
            res = minimize(func,
                           delta,
                           method="Powell",
                           )
            delta_diff = copy.deepcopy(res.x)

            # Add this to our final delta
            delta += delta_diff

            # If the changes are very small from the previous delta, then just call this converged and jump out
            if np.all(
                    np.isclose(
                        delta,
                        prev_delta,
                        atol=plane_fit_abs_tol,
                        rtol=plane_fit_rel_tol,
                    )
            ):
                converged = True
                log.debug(f"Plane fitting converged after {n_iter} iterations")

            n_iter += 1

        if not converged:
            log.debug(f"Plane fitting did not converge after {n_iter - 1} iterations")

        # Get residuals from the best plane
        best_plane = plane(
            jjs,
            iis,
            delta,
        )
        resid = diffs - best_plane

        # Get a measure of the RMS
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            _, _, rms = sigma_clipped_stats(resid,
                                            sigma=self.sigma,
                                            maxiters=maxiters,
                                            )

        # Only plot if asked to. Previously this ran unconditionally and
        # wrote "None_plane_fit.*" files when plotting was off
        if plot_name is not None:
            ii_min, ii_max = np.nanmin(iis), np.nanmax(iis)
            jj_min, jj_max = np.nanmin(jjs), np.nanmax(jjs)

            # Subsample residuals for the histogram, and points to scatter
            if self.max_points is not None and len(resid) > self.max_points:
                resid_hist = random.sample(list(resid), self.max_points)
                scatter_idx = np.array(random.sample(range(len(resid)), self.max_points))
            else:
                resid_hist = resid
                scatter_idx = slice(None)

            hist_range = np.nanpercentile(resid, [1, 99])

            # Get the plane to show the fit
            xx, yy = np.meshgrid([jj_min, jj_max], [ii_min, ii_max])
            z = delta[0] * xx + delta[1] * yy + delta[2]

            # Make a plot that shows the plane fit through the points on the left, and the
            # residuals on the right
            fig_plane = plt.figure(figsize=(8, 4))
            ax_3d = fig_plane.add_subplot(1, 2, 1, projection='3d')
            ax_3d.set_proj_type('ortho')
            ax_3d.view_init(elev=30, azim=45, roll=0)

            ax_3d.scatter(
                jjs[scatter_idx],
                iis[scatter_idx],
                diffs[scatter_idx],
                c='k',
                marker='.',
                alpha=0.1,
                rasterized=True,
            )
            ax_3d.plot_surface(xx, yy, z, alpha=0.4, color='red')

            ax_3d.set_xlim(jj_min, jj_max)
            ax_3d.set_ylim(ii_min, ii_max)
            ax_3d.ticklabel_format(useOffset=False)

            x_label = r"$\Delta x$" if ref_x is not None else r"$x$"
            y_label = r"$\Delta y$" if ref_y is not None else r"$y$"

            ax_3d.set_xlabel(f"{x_label} (pix)")
            ax_3d.set_ylabel(f"{y_label} (pix)")
            ax_3d.set_zlabel("Diff (MJy/sr)")

            ax_3d.xaxis.labelpad = 10
            ax_3d.yaxis.labelpad = 10

            ax_resid = fig_plane.add_subplot(1, 2, 2)
            ax_resid.hist(
                resid_hist,
                histtype="step",
                bins=50,
                range=hist_range,
                color="gray",
            )
            ax_resid.set_xlabel("Residual (MJy/sr)")
            ax_resid.set_ylabel("$N$")
            ax_resid.yaxis.set_label_position("right")
            ax_resid.yaxis.tick_right()
            ax_resid.grid()

            fig_plane.tight_layout()
            fig_plane.savefig(f"{plot_name}_plane_fit.pdf", bbox_inches="tight", dpi=300)
            fig_plane.savefig(f"{plot_name}_plane_fit.png", bbox_inches="tight", dpi=300)
            plt.close(fig_plane)

        # Translate this back to the central coordinate, since we subtracted that off earlier
        z_offset = delta[0] * mean_j + delta[1] * mean_i
        delta[-1] -= z_offset

    else:
        raise ValueError(f"fit_type should be one of {ALLOWED_FIT_TYPES}")

    gc.collect()

    return n_pix, delta, rms, lin_size
[docs]
def find_optimum_deltas(
    self,
    delta_mat,
    npix_mat,
    rms_mat,
    lin_size_mat,
    valid_mat,
    fit_type="level",
    n_draws=25,
    n_iter=10000,
    convergence_abs_tol=1e-8,
    convergence_rel_tol=1e-5,
    ref_idx=None,
):
    """Get optimum deltas from a delta/weight matrix.

    Taken from the JWST skymatch step, with some edits to remove potentially bad fits due
    to small areal overlaps, or noisy diffs, and various weighting schemes.

    If we're fitting a plane, delta_mat will be an NxNx3 matrix, and we'll minimize over
    each of the last axes separately.

    Args:
        delta_mat (np.ndarray): Matrix of delta values. These may be [a, b, c] coefficients
            if we're fitting a plane
        npix_mat (np.ndarray): Matrix of number of pixel values for calculating delta
        rms_mat (np.ndarray): Matrix of RMS values
        lin_size_mat (np.ndarray): 1/0 array for whether overlaps pass minimum linear extent
        valid_mat (np.ndarray): 1/0 array for whether overlaps are valid
        fit_type: Which type of fit to do. See ALLOWED_FIT_TYPES. Defaults to "level"
        n_draws: When using the iterative method, we need to sample from the fitted plane. This
            controls how many draws we do. Defaults to 25
        n_iter: Maximum number of iterations before breaking out of the fitting routine. Defaults
            to 10,000
        convergence_abs_tol: Absolute tolerance to define convergence. Defaults to 1e-8
        convergence_rel_tol: Relative tolerance to define convergence. Defaults to 1e-5
        ref_idx: Index to define the zero level for all the level matching. Defaults to None, which
            will use the average correction over all images that have usable overlaps

    Returns:
        np.ndarray: (N, delta_mat.shape[-1]) array of corrections, relative to the
            reference image (or the mean correction). Rows for images with no usable
            overlaps are 0

    Raises:
        ValueError: If weight_method or fit_type is not known
    """
    delta_mat = copy.deepcopy(delta_mat)
    npix_mat = copy.deepcopy(npix_mat)
    rms_mat = copy.deepcopy(rms_mat)
    lin_size_mat = copy.deepcopy(lin_size_mat)

    ns = delta_mat.shape[0]
    n_coeff = delta_mat.shape[-1]

    # Matrix for fits that we'll actually use
    use_mat = copy.deepcopy(valid_mat)

    # Remove things with weights less than min_area_percent of the average weight. Use all overlaps here
    avg_npix_val = np.nanmean(npix_mat[valid_mat == 1])
    use_mat[npix_mat < self.min_area_percent * avg_npix_val] = 0

    # Remove things that haven't passed the minimum linear extent test
    use_mat[lin_size_mat == 0] = 0

    # Remove fits with RMS values some sigma above the mean. Use only good overlaps here
    good_rms = rms_mat[np.logical_and(valid_mat == 1, use_mat == 1)]
    avg_rms_val = np.nanmean(good_rms)
    sig_rms_val = np.nanstd(good_rms)
    use_mat[rms_mat > avg_rms_val + self.rms_sig_limit * sig_rms_val] = 0

    # Create weight matrix
    if self.weight_method == "equal":
        # Weight evenly
        weight = np.ones_like(use_mat)
    elif self.weight_method == "npix":
        # Weight by straight number of pixels
        weight = 0.5 * (npix_mat + npix_mat.T)
    elif self.weight_method == "rms":
        # Weight by inverse variance of the fit
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            weight = 0.5 * (rms_mat + rms_mat.T)
            weight = weight ** -2
        weight[~np.isfinite(weight)] = 0
    else:
        raise ValueError(f"weight_method {self.weight_method} not known")

    # Build the weighted pairwise-equation matrix. It is only used to check how
    # well constrained the system is, and to flag images with no usable overlaps
    invalid = ns * [True]
    k_rows = []
    for i in range(ns):
        for j in range(i + 1, ns):
            if valid_mat[i, j] == 1 and use_mat[i, j] == 1:
                k_row = np.zeros(ns, dtype=float)
                k_row[i] = weight[i, j]
                k_row[j] = -weight[i, j]
                k_rows.append(k_row)
                invalid[i] = False
                invalid[j] = False
    k = np.array(k_rows, dtype=float).reshape(len(k_rows), ns)

    rank = np.linalg.matrix_rank(k, 1.0e-12)
    if rank < ns - 1:
        log.warning(
            f"There are more unknown sky values ({ns}) to be solved for"
        )
        log.warning(
            "than there are independent equations available "
            f"(matrix rank={rank})."
        )
        log.warning("Sky matching (delta) values will be computed only for")
        log.warning("a subset (or more independent subsets) of input images.")

    # Uses the iterative montage method to find best fits
    deltas = np.zeros([ns, n_coeff])
    delta_mat_corr = copy.deepcopy(delta_mat)

    # Only 'level+slope' fits start with a level-only warm-up phase
    has_level_phase = fit_type in ["level+slope"]

    # Set a max number of iterations to just run level matching. This is
    # either 2500 or half the number of iterations if the iteration
    # number is relatively small
    n_level = n_iter // 2 if n_iter < 5000 else 2500

    converged = False
    level_converged = False
    iteration = 0
    level_iteration = 0

    # Iterate until convergence or some maximum number of iterations
    while not converged and iteration < n_iter:

        # We'll start off just trying to optimize the DC offsets
        if has_level_phase and not level_converged:
            curr_fit_type = "level"
        else:
            curr_fit_type = fit_type

        # Collect best fits in a dictionary, since we'll update everything in
        # bulk at the end of the iteration
        best_fits = {}
        deltas_prev = copy.deepcopy(deltas)

        for i in range(ns):
            if invalid[i]:
                continue

            # Current (partially corrected) deltas of image i against each of its
            # usable overlaps. Only take fits we actually want to use
            neighbour_deltas = {
                j: copy.deepcopy(delta_mat_corr[i, j, :])
                for j in range(ns)
                if j != i
                and not invalid[j]
                and valid_mat[i, j] == 1
                and use_mat[i, j] == 1
            }

            # Sample some points in the (x, y) plane, evaluate each overlap's fit
            # there, and fit all of those samples together
            x = np.random.normal(loc=0, scale=1, size=n_draws)
            y = np.random.normal(loc=0, scale=1, size=n_draws)

            x_vals, y_vals, z_vals, weight_vals = [], [], [], []
            for j, arr_val in neighbour_deltas.items():
                if curr_fit_type == "level":
                    coeffs = np.array([0, 0, arr_val[-1]])
                elif curr_fit_type == "level+slope":
                    coeffs = arr_val
                else:
                    raise ValueError(f"Unknown fit type: {curr_fit_type}")

                z = coeffs[0] * x + coeffs[1] * y + coeffs[2]

                weight_vals.extend([weight[i, j]] * n_draws)
                x_vals.extend(x)
                y_vals.extend(y)
                z_vals.extend(z)

            x_vals = np.array(x_vals)
            y_vals = np.array(y_vals)
            z_vals = np.array(z_vals)
            weight_vals = np.array(weight_vals)

            if curr_fit_type == "level":
                # Just get an average of the offsets
                best_fit_vals = np.average(z_vals,
                                           weights=np.sqrt(weight_vals),
                                           keepdims=True,
                                           )
                best_fit = np.zeros(n_coeff)
                best_fit[-1] = best_fit_vals[-1]
            elif curr_fit_type == "level+slope":
                # For a starting guess, take the average of the various coefficients
                p0 = np.nanmean(
                    np.array([list(v) for v in neighbour_deltas.values()]),
                    axis=0,
                )

                # Use scipy minimize to get a best fit plane. Here, the error is
                # the inverse of the (square root) weights
                points = np.vstack((x_vals, y_vals, z_vals)).T
                func = partial(plane_resid,
                               points=points,
                               err=np.sqrt(weight_vals) ** -1,
                               )
                res = minimize(func,
                               p0,
                               method="Powell",
                               )
                best_fit = copy.deepcopy(res.x)
            else:
                raise ValueError(f"Unknown fit type: {curr_fit_type}")

            # Factor of 2 to keep things symmetrical
            best_fit /= 2
            best_fits[i] = best_fit

        # Add on the best fits
        for i, best_fit in best_fits.items():
            deltas[i, :] += best_fit

        # Edit the corrections within the matrix, doing this for all valid fits.
        # best_fits holds exactly the images that aren't invalid
        for i in best_fits:
            for j in best_fits:
                if i == j or use_mat[i, j] == 0:
                    continue
                delta_mat_corr[i, j, :] -= best_fits[i]
                delta_mat_corr[i, j, :] += best_fits[j]

        # Check for convergence. If the delta values haven't
        # changed within the tolerance, jump out and call it a day
        has_converged = np.all(
            np.isclose(
                deltas,
                deltas_prev,
                atol=convergence_abs_tol,
                rtol=convergence_rel_tol,
            )
        )
        if has_converged:
            if curr_fit_type != fit_type:
                level_converged = True
            else:
                converged = True
            log.debug(f"{curr_fit_type} matching converged after {iteration} iterations")

        # Update iteration counters
        if not level_converged:
            level_iteration += 1
        if not converged:
            iteration += 1

        # The level-only cap only applies when there is a level warm-up phase
        if has_level_phase and level_iteration >= n_level and not level_converged:
            log.debug(f"{curr_fit_type} matching has not converged after {iteration} iterations")
            level_converged = True
        if iteration >= n_iter and not converged:
            log.debug(f"{curr_fit_type} matching has not converged after {iteration} iterations")
            converged = True

    invalid_mask = np.asarray(invalid, dtype=bool)

    if ref_idx is None:
        # Take the average correction over images that actually got one. Invalid
        # images are stuck at 0 and would otherwise drag the mean towards 0
        if np.any(~invalid_mask):
            offset_delta = np.nanmean(deltas[~invalid_mask], axis=0)
        else:
            offset_delta = np.zeros(n_coeff)
    else:
        offset_delta = copy.deepcopy(deltas[ref_idx, :])

    # Set the reference correction to 0, adjust all other
    # corrections relative to that
    deltas -= offset_delta

    # Set any invalid deltas to 0
    deltas[invalid_mask, :] = 0

    return deltas
[docs]
def get_reproject(
    self,
    file,
    optimal_wcs,
    optimal_shape,
    stacked_image=False,
):
    """Reproject a file (or list of files) onto a common grid, keeping the list structure

    Both the science data and the error extension are reprojected for each file.

    Args:
        file: List or single file to reproject
        optimal_wcs: WCS to reproject to
        optimal_shape: output array shape for the WCS
        stacked_image (bool): Whether this is a stacked image or not. Defaults to False

    Returns:
        A dict {"data": ..., "err": ...} for a single file, or a list of such
        dicts (in input order) if ``file`` is a list
    """

    def _reproject_data_and_err(single_file):
        # Shared settings for both the data and error reprojection
        shared_kwargs = dict(
            optimal_wcs=optimal_wcs,
            optimal_shape=optimal_shape,
            do_sigma_clip=self.do_sigma_clip,
            stacked_image=stacked_image,
            reproject_func=self.reproject_func,
        )
        data_reproj = reproject_image(single_file, **shared_kwargs)
        err_reproj = reproject_image(single_file, hdu_type="err", **shared_kwargs)
        return {"data": data_reproj, "err": err_reproj}

    if not isinstance(file, list):
        return _reproject_data_and_err(file)

    return [_reproject_data_and_err(single) for single in file]
[docs]
def parallel_get_reproject(
    self,
    idx,
    files,
    optimal_wcs,
    optimal_shape,
    stacked_image=False,
):
    """Reproject one entry of a file list by index, for use with a pool map

    Args:
        idx: Index into ``files`` of the entry to reproject
        files: Full file list
        optimal_wcs: Optimal WCS for input stack of images
        optimal_shape: Optimal shape for input stack of images
        stacked_image: Stacked image or not? Defaults to False

    Returns:
        Tuple of (idx, reprojected entry), so that results coming back from an
        unordered map can be put back into place
    """
    reprojected = self.get_reproject(
        file=files[idx],
        optimal_wcs=optimal_wcs,
        optimal_shape=optimal_shape,
        stacked_image=stacked_image,
    )
    return idx, reprojected
[docs]
def get_plot_name(
    self,
    files1,
    files2,
):
    """Build the output plot name stem for level matching between two (groups of) files

    The name takes the form ``{files2 stem}_to_{files1 stem}_level_match`` inside
    ``self.plot_dir``.

    Args:
        files1: First file, or list of files (a dither group)
        files2: Second file, or list of files (a dither group)

    Returns:
        Path of the plot name stem (without extension)
    """

    def _name_stem(entry):
        if isinstance(entry, list):
            # Dither group: name after the first member, with the dither field blanked out
            name_parts = os.path.split(entry[0])[-1].split("_")
            name_parts[2] = "XXXXX"
        else:
            name_parts = os.path.split(entry)[-1].split("_")
        # Drop the final underscore-separated field
        return "_".join(name_parts[:-1])

    plot_to_name = _name_stem(files1)
    plot_from_name = _name_stem(files2)

    return os.path.join(
        self.plot_dir,
        f"{plot_from_name}_to_{plot_to_name}_level_match",
    )