Source code for pjpipe.download.download_step

import gc
import glob
import logging
import multiprocessing as mp
import os
import shutil
import warnings
from functools import partial

import astropy.units as u
import numpy as np
from astropy.table import unique, vstack
from astroquery.exceptions import NoResultsWarning
from astroquery.mast import Observations
from requests.exceptions import ConnectionError, ChunkedEncodingError
from stdatamodels.jwst import datamodels
from tqdm import tqdm

log = logging.getLogger("stpipe")
log.addHandler(logging.NullHandler())


def download(
    observations,
    products,
    max_retries=5,
):
    """Wrap around astroquery download with retry option

    Args:
        observations: astroquery observation instance
        products: List of products to download
        max_retries (int): Maximum number of retries before failing.
            Defaults to 5
    """

    for product in products:
        retry = 0
        success = False

        while retry < max_retries and not success:
            try:
                observations.download_products(product)
                success = True
            except (ConnectionResetError, ConnectionError, ChunkedEncodingError):
                retry += 1

            if retry == max_retries:
                raise Warning("Max retries exceeded!")

    return True


def parallel_verify_integrity(
    file,
):
    """Wrapper to parallelise checking file integrity

    Args:
        file: File to check
    """

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("error")
            im = datamodels.open(file)
        del im
        gc.collect()
        return True
    except:
        gc.collect()
        return file


warnings.simplefilter("error", NoResultsWarning)


[docs] class DownloadStep: def __init__( self, target=None, prop_id=None, download_dir=None, procs=1, radius=None, telescope="JWST", instrument_name=None, dataproduct_type=None, calib_level=None, extension=None, do_filter=True, breakout_targets=False, product_type=None, filter_gs=True, login=False, api_key=None, overwrite=False, ): """Query and download data from MAST Args: target: Target name. Defaults to None, which will use the proposal ID to figure out names prop_id: Proposal identifier. Defaults to None, but at least one of `target`, `prop_id` needs to be set download_dir: Where to download to. Defaults to current directory if not set procs: Number of processes to run in parallel. Defaults to 1 radius: Radius to search from centre of target in arcmin. Defaults to None, which uses astroquery default telescope: Telescope to use. Defaults to 'JWST' instrument_name: Something like 'NIRCAM/IMAGE', to filter down what products you want. Defaults to None (no filtering) dataproduct_type: These are things like 'cube', 'image' etc. Defaults to None (no filtering) calib_level: Calibration levels to filter. Defaults to None (no filtering) extension: File extensions to filter by. Defaults to None (no filtering) do_filter: Whether to filter dataproduct_type/calib_level/extension. Defaults to True breakout_targets: If targets is None, this will either download the whole proposal to a single directory (False), or breakout each target name to its own directory (True). Defaults to False product_type: Product types to download. Defaults to None (no filtering) filter_gs: Whether to filter guide stars from the observations. Defaults to True login: For proprietary data, set to True. Defaults to False api_key: If login is True, supply API key here overwrite: If True, overwrites existing data. Defaults to False """ if not target and not prop_id: raise Warning("Either one of target, prop_id should be specified!") if download_dir is None: download_dir = os.getcwd() if calib_level is None: calib_level = [ 1, 2, 3, ] if product_type is None: product_type = [ "SCIENCE", "PREVIEW", "INFO", "AUXILIARY", ] # Make sure the radius is in arcmin if not isinstance(radius, u.Quantity) and radius is not None: radius = radius * u.arcmin if isinstance(telescope, str): telescope = [telescope] if isinstance(prop_id, str): prop_id = [prop_id] if isinstance(instrument_name, str): instrument_name = [instrument_name] if isinstance(dataproduct_type, str): dataproduct_type = [dataproduct_type] self.target = target self.download_dir = download_dir self.procs = procs self.radius = radius self.telescope = telescope self.prop_id = prop_id self.instrument_name = instrument_name self.dataproduct_type = dataproduct_type self.calib_level = calib_level self.extension = extension self.do_filter = do_filter self.breakout_targets = breakout_targets self.filter_gs = filter_gs self.login = login self.api_key = api_key self.product_type = product_type self.overwrite = overwrite self.obs_list = None self.observations = Observations()
[docs] def do_step(self): """Run download step""" step_complete_file = os.path.join( self.download_dir, "download_step_complete.txt" ) cwd = os.getcwd() if not os.path.exists(self.download_dir): os.makedirs(self.download_dir) os.chdir(self.download_dir) if self.overwrite: shutil.rmtree("mastDownload") if os.path.exists(step_complete_file): log.info("Step already run") os.chdir(cwd) return True if self.login: if self.api_key is None: self.api_key = input("Supply an API key: ") self.observations.login(token=self.api_key) self.run_archive_query() if self.obs_list is None: shutil.rmtree(self.download_dir) os.chdir(cwd) return False success = self.run_download() if not success: os.chdir(cwd) return False # Check that all the files work now, else stop to rerun success = self.verify_integrity() if not success: raise Warning( "Integrity verification failed. Files will have been deleted, so run again" ) # Make a file to let us know we've already downloaded # these files with open(step_complete_file, "w+") as f: f.close() os.chdir(cwd) return True
[docs] def run_archive_query(self): """Query archive, trimming down as requested on instrument etc.""" log.info(f"Downloading to: {self.download_dir}") log.info(f"Beginning archive query:") log.info(f"-> Target: {self.target}") log.info(f"-> Telescope: {self.telescope}") log.info(f"-> Proposal ID: {self.prop_id}") log.info(f"-> Instrument name: {self.instrument_name}") log.info(f"-> Data product type: {self.dataproduct_type}") if self.target is None: # If we don't have a target, fall back to just querying the proposal ID self.obs_list = self.observations.query_criteria(proposal_id=self.prop_id) else: # Query by target if self.radius is None: self.obs_list = self.observations.query_object(self.target) else: self.obs_list = self.observations.query_object( self.target, radius=self.radius ) if np.all(self.obs_list["calib_level"] < 0): log.warning("No available data") self.obs_list = None return False self.obs_list = self.obs_list[self.obs_list["calib_level"] >= 0] # Filter down by various options if self.telescope is not None: idx = [obs["obs_collection"] in self.telescope for obs in self.obs_list] self.obs_list = self.obs_list[idx] if self.prop_id is not None: idx = [obs["proposal_id"] in self.prop_id for obs in self.obs_list] self.obs_list = self.obs_list[idx] if self.instrument_name is not None: idx = [ obs["instrument_name"] in self.instrument_name for obs in self.obs_list ] self.obs_list = self.obs_list[idx] if self.dataproduct_type is not None: idx = [ obs["dataproduct_type"] in self.dataproduct_type for obs in self.obs_list ] self.obs_list = self.obs_list[idx] if len(self.obs_list) == 0: log.warning("No available data") self.obs_list = None return False return True
[docs] def run_download( self, max_retries=5, ): """Download a list of observations Args: max_retries (int): Maximum number of retries before failing. Defaults to 5 """ if self.target is None and self.breakout_targets: # Break out by targets targets = np.unique(self.obs_list["target_name"]) for target in targets: obs_list = self.obs_list[self.obs_list["target_name"] == target] target = target.lower() target = target.replace(" ", "_") target_dir = os.path.join(self.download_dir, target) if not os.path.exists(target_dir): os.makedirs(target_dir) os.chdir(target_dir) products = self.get_products(obs_list) success = False n_retries = 0 downloaded_files = glob.glob( os.path.join( "*", "*", "*", "*", ) ) n_files = len(downloaded_files) log.info(f"Downloading {len(products)} files") if n_files == len(products): log.info(f"All files already downloaded") os.chdir(self.download_dir) continue while not success and n_retries < max_retries: download(self.observations, products) n_retries += 1 downloaded_files = glob.glob( os.path.join( "*", "*", "*", "*", ) ) n_files = len(downloaded_files) if n_files == len(products): success = True os.chdir(self.download_dir) if n_retries == max_retries: raise Warning("Max retries exceeded!") else: products = self.get_products(self.obs_list) n_retries = 0 downloaded_files = glob.glob( os.path.join( "*", "*", "*", "*", ) ) n_files = len(downloaded_files) log.info(f"Downloading {len(products)} files") if n_files == len(products): log.info(f"All files already downloaded") return True while n_retries < max_retries: download(self.observations, products) n_retries += 1 downloaded_files = glob.glob( os.path.join( "*", "*", "*", "*", ) ) n_files = len(downloaded_files) if n_files == len(products): return True raise Warning("Max retries exceeded!") return True
[docs] def get_products( self, obs_list, ): """Get products from an observations list Args: obs_list: List of observations to get products from """ # Flatten down all the observations products = [] log.info("Getting observations") for obs in tqdm( obs_list, ascii=True, ): try: product_list = self.observations.get_product_list(obs) products.append(product_list) except NoResultsWarning: log.warning(f"Data not available for {obs['obs_id']}") continue products = vstack(products) log.info(f"Found a total {len(products)} files. Performing cuts") # Filter out guide stars if requested if self.filter_gs: mask = np.char.find(products["dataURI"], "_gs-") == -1 products = products[mask] # Perform filtering if requested if self.do_filter: products = self.observations.filter_products( products, calib_level=self.calib_level, productType=self.product_type, extension=self.extension, ) # Finally, remove duplicates and sort if "dataURI" in products.colnames: products = unique( products, keys="dataURI", ) products.sort("dataURI") return products
[docs] def verify_integrity(self): """Check files have downloaded successfully""" log.info("Verifying file integrity") downloaded_files = glob.glob( os.path.join(self.download_dir, "*", "*", "*", "*.fits") ) downloaded_files.sort() # Ensure we're not wasting processes procs = np.nanmin([self.procs, len(downloaded_files)]) bad_files = [] with mp.get_context("fork").Pool(procs) as pool: for bad_file in tqdm( pool.imap_unordered( partial( parallel_verify_integrity, ), downloaded_files, ), ascii=True, desc="Verifying integrity", total=len(downloaded_files), ): bad_files.append(bad_file) pool.close() pool.join() gc.collect() # Filter out the good files (Trues) bad_files = [file for file in bad_files if file is not True] if len(bad_files) == 0: log.info("All files verified successfully") return True else: log.warning(f"Found {len(bad_files)} bad files, will remove:") for file in bad_files: log.warning(f"-> {file}") os.remove(file) return False