Source code for romancal.assign_wcs.assign_wcs

"""
WCS construction utilities for Roman WFI images.
"""

import logging

import gwcs.coordinate_frames as cf
import numpy as np
from astropy import coordinates as coord
from astropy import units as u
from astropy.modeling import bind_bounding_box
from astropy.modeling.models import Identity, RotationSequence3D, Scale, Shift
from gwcs.geometry import CartesianToSpherical, SphericalToCartesian
from gwcs.wcs import WCS, Step
from roman_datamodels import datamodels as rdm
from stcal.alignment.util import compute_s_region_keyword, wcs_bbox_from_shape

# Module-level logger. Its own level is pinned to DEBUG, so records from this
# module are not filtered at the logger itself; whether they are displayed
# still depends on the handlers configured by the application.
# NOTE(review): libraries usually leave the level to the application — confirm
# this override is intentional.
log = logging.getLogger(name=__name__)
log.setLevel(level=logging.DEBUG)


def load_wcs(input_model, reference_files=None):
    """Create a gWCS object and store it in ``Model.meta``.

    Parameters
    ----------
    input_model : `~roman_datamodels.datamodels.WfiImage`
        The exposure.
    reference_files : dict
        A dict {reftype: reference_file_name} containing all
        reference files that apply to this exposure. Entries whose value is
        ``"N/A"`` or ``""`` are treated as missing (``None``). The caller's
        dict is not modified.

    Returns
    -------
    output_model : `~roman_datamodels.ImageModel`
        The input image file with attached gWCS object.
        The input_model is modified in place.
    """
    output_model = input_model

    # Normalise placeholder reference-file names to ``None`` in a *new* dict,
    # so the mapping passed in by the caller is left untouched.
    if reference_files is None:
        reference_files = {}
    else:
        reference_files = {
            ref_type: (None if ref_file in ["N/A", ""] else ref_file)
            for ref_type, ref_file in reference_files.items()
        }

    # Frames
    detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix))
    v2v3 = cf.Frame2D(
        name="v2v3",
        axes_order=(0, 1),
        axes_names=("v2", "v3"),
        unit=(u.arcsec, u.arcsec),
    )
    v2v3vacorr = cf.Frame2D(
        name="v2v3vacorr",
        axes_order=(0, 1),
        axes_names=("v2", "v3"),
        unit=(u.arcsec, u.arcsec),
    )
    world = cf.CelestialFrame(reference_frame=coord.ICRS(), name="world")

    # Transforms between frames
    distortion = _wfi_distortion(output_model, reference_files)
    tel2sky = v23tosky(output_model)

    # Compute differential velocity aberration (DVA) correction:
    va_corr = _dva_corr_model(
        va_scale=output_model.meta.velocity_aberration.scale_factor,
        v2_ref=output_model.meta.wcsinfo.v2_ref,
        v3_ref=output_model.meta.wcsinfo.v3_ref,
    )

    # detector -> v2v3 -> v2v3vacorr -> world
    pipeline = [
        Step(detector, distortion),
        Step(v2v3, va_corr),
        Step(v2v3vacorr, tel2sky),
        Step(world, None),
    ]
    wcs = WCS(pipeline)
    # Fall back to the full image extent if no bounding box was propagated.
    if wcs.bounding_box is None:
        wcs.bounding_box = wcs_bbox_from_shape(output_model.data.shape)

    output_model.meta["wcs"] = wcs

    # update S_REGION
    add_s_region(output_model)

    output_model.meta.cal_step["assign_wcs"] = "COMPLETE"

    return output_model
def _wfi_distortion(model, reference_files):
    """
    Create the "detector" to "v2v3" transform for WFI

    Parameters
    ----------
    model : `~roman_datamodels.datamodels.WfiImage`
        The data model for processing
    reference_files : dict
        A dict {reftype: reference_file_name} containing all
        reference files that apply to this exposure.

    Returns
    -------
    The transform model, with a bounding box bound in ``order="F"``.
    """
    dist = rdm.DistortionRefModel(reference_files["distortion"])
    try:
        transform = dist.coordinate_distortion_transform
        try:
            bbox = transform.bounding_box.bounding_box(order="F")
        except NotImplementedError:
            # Check if the transform in the reference file has a ``bounding_box``.
            # If not set a ``bounding_box`` equal to the size of the image after
            # assembling all distortion corrections.
            bbox = None
    finally:
        # Release the reference file even if reading the transform fails.
        dist.close()

    bind_bounding_box(
        transform,
        wcs_bbox_from_shape(model.data.shape) if bbox is None else bbox,
        order="F",
    )

    return transform


def v23tosky(input_model, wrap_v2_at=180, wrap_lon_at=360):
    """Create the transform from telescope to sky.

    The transform is defined with a reference point in a Frame associated
    with the telescope (V2, V3) in arcsec, the corresponding reference point
    on sky (RA_REF, DEC_REF) in deg, and the position angle at the center of
    the aperture, ROLL_REF in deg.

    Parameters
    ----------
    input_model : `roman_datamodels.WfiImage`
        Roman imaging exposure data model.
    wrap_v2_at : float
        At what angle to wrap V2. [deg]
    wrap_lon_at : float
        At what angle to wrap longitude. [deg]

    Returns
    -------
    model : `astropy.modeling.Model`
        The transform from V2,V3 to sky.
    """
    # V2/V3 reference values are stored in arcsec; the rotation works in deg.
    v2_ref = input_model.meta.wcsinfo.v2_ref / 3600
    v3_ref = input_model.meta.wcsinfo.v3_ref / 3600
    roll_ref = input_model.meta.wcsinfo.roll_ref
    ra_ref = input_model.meta.wcsinfo.ra_ref
    dec_ref = input_model.meta.wcsinfo.dec_ref

    angles = np.array([v2_ref, -v3_ref, roll_ref, dec_ref, -ra_ref])
    axes = "zyxyz"
    rot = RotationSequence3D(angles, axes_order=axes)

    # The sky rotation expects values in deg.
    # This should be removed when models work with quantities.
    model = (
        (Scale(1 / 3600) & Scale(1 / 3600))
        | SphericalToCartesian(wrap_lon_at=wrap_v2_at)
        | rot
        | CartesianToSpherical(wrap_lon_at=wrap_lon_at)
    )
    model.name = "v23tosky"
    return model


def _dva_corr_model(va_scale, v2_ref, v3_ref):
    """
    Create transformation that accounts for differential velocity aberration
    (scale).

    Parameters
    ----------
    va_scale : float, None
        Ratio of the apparent plate scale to the true plate scale. When
        ``va_scale`` is `None`, it is assumed to be identical to ``1`` and
        an ``astropy.modeling.models.Identity`` model will be returned.
        A value that is not strictly positive (including NaN) is rejected
        with a warning and treated as ``1``.

    v2_ref : float, None
        Telescope ``v2`` coordinate of the reference point in ``arcsec``. When
        ``v2_ref`` is `None`, it is assumed to be identical to ``0``.

    v3_ref : float, None
        Telescope ``v3`` coordinate of the reference point in ``arcsec``. When
        ``v3_ref`` is `None`, it is assumed to be identical to ``0``.

    Returns
    -------
    va_corr : astropy.modeling.CompoundModel, astropy.modeling.models.Identity
        A 2D compound model that corrects DVA. If ``va_scale`` is `None`, 1,
        or invalid, then `astropy.modeling.models.Identity` will be returned.
    """
    if va_scale is None or va_scale == 1:
        return Identity(2)

    # ``not va_scale > 0`` (rather than ``va_scale <= 0``) also rejects NaN,
    # which would otherwise propagate into every world coordinate.
    if not va_scale > 0:
        log.warning("Given velocity aberration scale %s", va_scale)
        log.warning(
            "Velocity aberration scale must be a positive number. Setting to 1.0"
        )
        # A unit scale means no scaling and zero shifts, i.e. the identity.
        return Identity(2)

    va_corr = Scale(va_scale, name="dva_scale_v2") & Scale(
        va_scale, name="dva_scale_v3"
    )

    if v2_ref is None:
        v2_ref = 0

    if v3_ref is None:
        v3_ref = 0

    if v2_ref == 0 and v3_ref == 0:
        return va_corr

    # NOTE: it is assumed that v2, v3 angles and va scale are small enough
    # so that for expected scale factors the issue of angle wrapping
    # (180 degrees) can be neglected.
    # The shift keeps the reference point fixed under the scaling.
    v2_shift = (1 - va_scale) * v2_ref
    v3_shift = (1 - va_scale) * v3_ref

    va_corr |= Shift(v2_shift, name="dva_v2_shift") & Shift(
        v3_shift, name="dva_v3_shift"
    )
    va_corr.name = "DVA_Correction"
    return va_corr


def _create_footprint(wcs, shape=None, center=False):
    """Calculate sky footprint

    Parameters
    ----------
    wcs : `gwcs.WCS`
        The WCS information to get the footprint from
    shape : n-tuple or None
       Shape to use if wcs has no defined shape.
    center : bool
        If True use the center of the pixel, otherwise use the corner.

    Returns
    -------
    footprint : `numpy.ndarray`
        The footprint, one row per corner with columns (RA, Dec); RA values
        are wrapped into the non-negative range.
    """
    bbox = wcs.bounding_box

    if bbox is None:
        bbox = wcs_bbox_from_shape(shape)

    # footprint is an array of shape (2, 4) - i.e. 4 values for RA and 4 values for
    # Dec - as we are interested only in the footprint on the sky
    footprint = wcs.footprint(bbox, center=center, axis_type="spatial").T
    # take only imaging footprint
    footprint = footprint[:2, :]

    # Make sure RA values are all positive
    footprint[0, footprint[0] < 0] += 360

    return footprint.T


def add_s_region(model):
    """
    Calculate the detector's footprint using ``WCS.footprint`` and save it in
    the ``S_REGION`` keyword

    Parameters
    ----------
    model : `~roman_datamodels.datamodels.ImageModel`
        The data model for processing. ``model.meta.wcsinfo.s_region`` is
        updated in place (unless the footprint contains NaNs).

    Returns
    -------
    None
    """
    _update_s_region_keyword(
        model, _create_footprint(model.meta.wcs, shape=model.shape, center=False)
    )


def _update_s_region_keyword(model, footprint):
    """Format ``footprint`` as an S_REGION string and store it on ``model``.

    The keyword is left unchanged when the formatted string contains ``nan``.

    Parameters
    ----------
    model : `~roman_datamodels.datamodels.ImageModel`
        The data model whose ``meta.wcsinfo.s_region`` is updated.
    footprint : `numpy.ndarray`
        Footprint corners as produced by ``_create_footprint``.
    """
    s_region = compute_s_region_keyword(footprint)
    log.info("S_REGION VALUES: %s", s_region)
    if "nan" in s_region:
        # do not update s_region if there are NaNs.
        log.info("There are NaNs in s_region, S_REGION not updated.")
    else:
        model.meta.wcsinfo.s_region = s_region
        log.info("Update S_REGION to %s", model.meta.wcsinfo.s_region)