# Source code for romancal.tweakreg.tweakreg_step

"""
Roman pipeline step for image alignment.
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from astropy.table import Table
from roman_datamodels import datamodels as rdm
from roman_datamodels import dqflags
from stcal.tweakreg import tweakreg
from stcal.tweakreg.tweakreg import TweakregError

from romancal.assign_wcs.assign_wcs import add_s_region
from romancal.datamodels.fileio import open_dataset
from romancal.lib.save_wcs import save_wfiwcs

# LOCAL
from ..datamodels import ModelLibrary
from ..stpipe import RomanStep

if TYPE_CHECKING:
    from typing import ClassVar

DEFAULT_ABS_REFCAT = "GAIADR3_S3"

__all__ = ["TweakRegStep"]

log = logging.getLogger(__name__)


class TweakRegStep(RomanStep):
    """
    TweakRegStep: Image alignment based on catalogs of sources detected
    in input images.
    """

    class_alias = "tweakreg"

    # configobj-style parameter schema parsed by stpipe; one option per line.
    # Relative-alignment parameters come first, followed by their `abs_*`
    # counterparts used for absolute alignment against `abs_refcat`.
    spec = f"""
        catalog_format = string(default='ascii.ecsv') # Catalog output file format
        catalog_path = string(default='') # Catalog output file path
        enforce_user_order = boolean(default=False) # Align images in user specified order?
        expand_refcat = boolean(default=False) # Expand reference catalog with new sources?
        minobj = integer(default=10) # Minimum number of objects acceptable for matching
        searchrad = float(default=2.0) # The search radius in arcsec for a match
        use2dhist = boolean(default=True) # Use 2d histogram to find initial offset?
        separation = float(default=1.0) # Minimum object separation in arcsec
        tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec
        fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='general') # Fitting geometry
        nclip = integer(min=0, default=3) # Number of clipping iterations in fit
        sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units
        abs_refcat = string(default='{DEFAULT_ABS_REFCAT}') # Absolute reference catalog
        save_abs_catalog = boolean(default=False) # Write out used absolute astrometric reference catalog as a separate product
        abs_minobj = integer(default=10) # Minimum number of objects acceptable for matching when performing absolute astrometry
        abs_searchrad = float(default=6.0) # The search radius in arcsec for a match when performing absolute astrometry
        # We encourage setting this parameter to True. Otherwise, xoffset and yoffset will be set to zero.
        abs_use2dhist = boolean(default=True) # Use 2D histogram to find initial offset when performing absolute astrometry?
        abs_separation = float(default=1.0) # Minimum object separation in arcsec when performing absolute astrometry
        abs_tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec when performing absolute astrometry
        # Fitting geometry when performing absolute astrometry
        abs_fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='general')
        abs_nclip = integer(min=0, default=3) # Number of clipping iterations in fit when performing absolute astrometry
        abs_sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units when performing absolute astrometry
        output_use_model = boolean(default=True) # When saving use `DataModel.meta.filename`
        update_source_catalog_coordinates = boolean(default=False) # Update source catalog file with tweaked coordinates?
        vo_timeout = float(min=0, default=1200.) # VO catalog service timeout.
    """

    # This step needs no CRDS reference files.
    reference_file_types: ClassVar = []
    def process(self, dataset):
        """
        Run tweakreg image alignment on the input dataset.

        Builds a per-image source catalog (from ``meta.source_catalog``),
        constructs WCS correctors, performs relative alignment across image
        groups and absolute alignment against ``self.abs_refcat``, then
        writes the fitted WCS and fit results back into each model.

        Parameters
        ----------
        dataset : str or ModelLibrary-compatible input
            The input association/models to align.

        Returns
        -------
        ModelLibrary
            The input library with updated WCS and ``cal_step.tweakreg``
            status on each image model.

        Raises
        ------
        ValueError
            If the input contains no image models, or a catalog lacks the
            required coordinate columns.
        AttributeError
            If ``meta.source_catalog`` (or its tweakreg catalog) is missing.
        """
        images = open_dataset(
            dataset, update_version=self.update_version, as_library=True
        )

        if not images:
            raise ValueError("Input must contain at least one image model.")

        log.info(
            f"Number of image groups to be aligned: {len(images.group_indices):d}."
        )
        log.info("Image groups:")
        for name in images.group_names:
            log.info(f" {name}")

        # set the first image as reference
        with images:
            ref_image = images.borrow(0)
            images.shelve(ref_image, 0, modify=False)

        # set path where the source catalog will be saved to
        if len(self.catalog_path) == 0:
            self.catalog_path = os.getcwd()
        self.catalog_path = Path(self.catalog_path).as_posix()
        log.info(f"All source catalogs will be saved to: {self.catalog_path}")

        # set reference catalog name
        if not self.abs_refcat:
            self.abs_refcat = DEFAULT_ABS_REFCAT.strip().upper()
        # NOTE(review): a non-default refcat forces refcat expansion —
        # presumably so user catalogs with sparse coverage still match.
        if self.abs_refcat != DEFAULT_ABS_REFCAT:
            self.expand_refcat = True

        # build the catalogs for input images
        imcats = []
        with images:
            for i, image_model in enumerate(images):
                exposure_type = image_model.meta.exposure.type
                if exposure_type != "WFI_IMAGE":
                    # tweakreg only applies to imaging exposures
                    log.info("Skipping TweakReg for spectral exposure.")
                    image_model.meta.cal_step.tweakreg = "SKIPPED"
                else:
                    source_catalog = getattr(image_model.meta, "source_catalog", None)
                    if source_catalog is None:
                        # shelve before raising so the library stays consistent
                        images.shelve(image_model, i, modify=False)
                        raise AttributeError(
                            "Attribute 'meta.source_catalog' is missing. "
                            "Please either run SourceCatalogStep or provide a custom source catalog."
                        )
                    try:
                        catalog = self.get_tweakreg_catalog(source_catalog, image_model)
                    except AttributeError as e:
                        log.error(f"Failed to retrieve tweakreg_catalog: {e}")
                        images.shelve(image_model, i, modify=False)
                        raise e

                    if len(catalog) == 0:
                        # for empty catalogs, SourceCatalog omits xpsf & ypsf; add them
                        _add_required_columns(catalog)

                    # validate catalog columns
                    if not _validate_catalog_columns(catalog):
                        raise ValueError(
                            "'tweakreg' source catalogs must contain a header with "
                            "columns named either 'x' and 'y' or 'x_psf' and 'y_psf'. "
                            "Neither were found in the catalog provided."
                        )

                    # keep only sources inside the image bounding box and
                    # drop flagged (DO_NOT_USE) sources
                    catalog = tweakreg.filter_catalog_by_bounding_box(
                        catalog, image_model.meta.wcs.bounding_box
                    )
                    catalog = _filter_catalog(catalog)

                    if self.save_abs_catalog:
                        output_name = os.path.join(
                            self.catalog_path,
                            f"fit_{self.abs_refcat.lower()}_ref.ecsv",
                        )
                        catalog.write(
                            output_name, format=self.catalog_format, overwrite=True
                        )

                    image_model.meta["tweakreg_catalog"] = catalog.as_array()
                    nsources = len(catalog)
                    log.info(
                        f"Using {nsources} sources from {image_model.meta.filename}."
                        if nsources
                        else f"No sources found in {image_model.meta.filename}."
                    )

                    # build image catalog
                    # catalog name
                    catalog_name = os.path.splitext(image_model.meta.filename)[0].strip(
                        "_- "
                    )
                    # catalog data
                    catalog_table = Table(image_model.meta.tweakreg_catalog)
                    catalog_table.meta["name"] = catalog_name

                    imcat = tweakreg.construct_wcs_corrector(
                        wcs=image_model.meta.wcs,
                        refang=image_model.meta.wcsinfo,
                        catalog=catalog_table,
                        group_id=images._model_to_group_id(image_model),
                    )
                    # remember which library slot this corrector belongs to
                    imcat.meta["model_index"] = i
                    imcats.append(imcat)
                images.shelve(image_model, i)

        # run alignment only if it was possible to build image catalogs
        if len(imcats):
            # extract WCS correctors to use for image alignment
            # (relative alignment needs at least two groups)
            if len(images.group_indices) > 1:
                try:
                    self.do_relative_alignment(imcats)
                except TweakregError as e:
                    log.warning(str(e))

            try:
                self.do_absolute_alignment(ref_image, imcats)
            except TweakregError as e:
                # absolute alignment failed: return without marking COMPLETE
                log.warning(str(e))
                return images

        # finalize step
        with images:
            for imcat in imcats:
                image_model = images.borrow(imcat.meta["model_index"])
                image_model.meta.cal_step.tweakreg = "COMPLETE"
                # remove source catalog
                del image_model.meta["tweakreg_catalog"]

                # retrieve fit status and update wcs if fit is successful:
                if "SUCCESS" in imcat.meta.get("fit_info")["status"]:
                    # Update/create the WCS .name attribute with information
                    # on this astrometric fit as the only record that it was
                    # successful:
                    # NOTE: This .name attrib agreed upon by the JWST Cal
                    # Working Group.
                    # Current value is merely a place-holder based
                    # on HST conventions. This value should also be
                    # translated to the FITS WCSNAME keyword
                    # IF that is what gets recorded in the archive
                    # for end-user searches.
                    imcat.wcs.name = f"FIT-LVL2-{self.abs_refcat}"

                    # serialize object from tweakwcs
                    # (typecasting numpy objects to python types so that it doesn't cause an
                    # issue when saving datamodel to ASDF)
                    wcs_fit_results = {
                        k: (
                            v.tolist()
                            if isinstance(v, np.ndarray | np.bool_)
                            else v
                        )
                        for k, v in imcat.meta["fit_info"].items()
                    }
                    # add fit results and new WCS to datamodel
                    image_model.meta["wcs_fit_results"] = wcs_fit_results
                    # remove unwanted keys from WCS fit results
                    for k in [
                        "eff_minobj",
                        "matched_ref_idx",
                        "matched_input_idx",
                        "fit_RA",
                        "fit_DEC",
                        "fitmask",
                    ]:
                        del image_model.meta["wcs_fit_results"][k]

                    # update WCS
                    image_model.meta.wcs = imcat.wcs
                    # update S_REGION
                    add_s_region(image_model)

                    # update source catalog coordinates if requested
                    if self.update_source_catalog_coordinates:
                        try:
                            self.update_catalog_coordinates(
                                image_model.meta.source_catalog[
                                    "tweakreg_catalog_name"
                                ],
                                imcat.wcs,
                            )
                        except Exception as e:
                            log.error(
                                f"Failed to update source catalog coordinates: {e}"
                            )
                            raise e

                images.shelve(image_model, imcat.meta["model_index"])

        return images
[docs] def save_model(self, result, *args, **kwargs): if isinstance(result, ModelLibrary): save_wfiwcs(self, result, force=True) super().save_model(result, *args, **kwargs)
[docs] def update_catalog_coordinates(self, tweakreg_catalog_name, tweaked_wcs): """ Update the source catalog coordinates using the tweaked WCS while strictly preserving original file metadata. Parameters ---------- tweakreg_catalog_name : str Path to the source catalog file (in Parquet format) to be updated. tweaked_wcs : callable A WCS transformation function that takes x and y coordinates and returns updated (RA, Dec) values. Returns ------- None The function updates the catalog file in place; it does not return a value. Notes ----- The method preserves all original file metadata by reading and re-attaching it after coordinate updates. Only the coordinate columns are modified; all other data and metadata remain unchanged. """ # Read the existing catalog using PyArrow pa_table = pq.read_table(tweakreg_catalog_name) original_metadata = pa_table.schema.metadata # Determine which coordinate columns are present and update them from pixel-space # coordinates using the tweaked WCS. available_cols = set(pa_table.schema.names) # (x_col, y_col) -> (ra_col, dec_col) updates = [ ("x_centroid", "y_centroid", "ra_centroid", "dec_centroid"), ("x_centroid", "y_centroid", "ra", "dec"), ( "x_centroid_win", "y_centroid_win", "ra_centroid_win", "dec_centroid_win", ), ("x_psf", "y_psf", "ra_psf", "dec_psf"), ] updated_columns: dict[str, pa.Array] = {} for x_col, y_col, ra_col, dec_col in updates: # Only update existing columns to preserve the file schema. if ( x_col in available_cols or y_col in available_cols or ra_col in available_cols or dec_col in available_cols ): x_values = pa_table[x_col].to_numpy() y_values = pa_table[y_col].to_numpy() new_ra, new_dec = tweaked_wcs(x_values, y_values) new_ra = np.asarray(getattr(new_ra, "value", new_ra)) new_dec = np.asarray(getattr(new_dec, "value", new_dec)) # Preserve the original column types. 
updated_columns[ra_col] = pa.array( new_ra, type=pa_table.schema.field(ra_col).type ) updated_columns[dec_col] = pa.array( new_dec, type=pa_table.schema.field(dec_col).type ) # Create new table with updated columns # Keep all original columns, replacing only the updated ones new_columns = [] new_names = [] for i, field in enumerate(pa_table.schema): col_name = field.name if col_name in updated_columns: # Use updated column new_columns.append(updated_columns[col_name]) else: # Keep original column new_columns.append(pa_table.column(i)) new_names.append(col_name) # Create new table with original schema metadata final_table = pa.table(new_columns, names=new_names) final_table = final_table.replace_schema_metadata(original_metadata) # Write back to file pq.write_table(final_table, tweakreg_catalog_name)
[docs] def read_catalog(self, catalog_name): """ Reads a source catalog from a specified file. This function determines the format of the catalog based on the file extension: * "asdf": uses roman datamodels * "parquet": uses pyarrow * otherwise: uses astropy Table. Parameters ---------- catalog_name : str The name of the catalog file to read. Returns ------- Table The read catalog as a Table object. Raises ------ ValueError If the catalog format is unsupported. """ filetype = ( "parquet" if catalog_name.endswith("parquet") else self.catalog_format ) if catalog_name.endswith("asdf"): # leave this for now with rdm.open(catalog_name) as source_catalog_model: catalog = source_catalog_model.source_catalog else: catalog = Table.read(catalog_name, format=filetype) return catalog
[docs] def get_tweakreg_catalog(self, source_catalog, image_model): """ Retrieve the tweakreg catalog from source detection. This method checks the source detection metadata for the presence of a tweakreg catalog data or a string with its name. It returns the catalog as a Table object if either is found, or raises an error if neither is available. Parameters ---------- source_catalog : object The source catalog metadata containing catalog information. image_model : DataModel The image model associated with the source detection. Returns ------- Table The retrieved tweakreg catalog as a Table object. Raises ------ AttributeError If the required catalog information is missing from the source detection. """ twk_cat = getattr(source_catalog, "tweakreg_catalog", None) twk_cat_name = getattr(source_catalog, "tweakreg_catalog_name", None) image_name = getattr( getattr(image_model, "meta", None), "filename", "<unknown>" ) if twk_cat is not None: log.info( f"Using in-memory tweakreg catalog from meta.source_catalog.tweakreg_catalog for {image_name}." ) tweakreg_catalog = Table(np.asarray(source_catalog.tweakreg_catalog)) del image_model.meta.source_catalog["tweakreg_catalog"] return tweakreg_catalog elif twk_cat_name is not None: log.info(f"Using tweakreg catalog file '{twk_cat_name}' for {image_name}.") return self.read_catalog(source_catalog.tweakreg_catalog_name) else: raise AttributeError( "Attribute 'meta.source_catalog.tweakreg_catalog' is missing. " "Please either run SourceCatalogStep or provide a custom source catalog." )
[docs] def do_relative_alignment(self, imcats): """ Perform relative alignment of images. This method performs relative alignment with the specified parameters, including search radius, separation, and fitting geometry. Parameters ---------- imcats : list A list of image catalogs containing source information for alignment. Returns ------- None """ tweakreg.relative_align( imcats, searchrad=self.searchrad, separation=self.separation, use2dhist=self.use2dhist, tolerance=self.tolerance, xoffset=0, yoffset=0, enforce_user_order=self.enforce_user_order, expand_refcat=self.expand_refcat, minobj=self.minobj, fitgeometry=self.fitgeometry, nclip=self.nclip, sigma=self.sigma, clip_accum=True, )
[docs] def do_absolute_alignment(self, ref_image, imcats): """ Perform absolute alignment of images. This method retrieves a reference image and performs absolute alignment using the specified parameters, including reference WCS information and catalog details. It aligns the provided image catalogs to the absolute reference catalog. Parameters ---------- ref_image : DataModel The reference image used for alignment, which contains WCS information. imcats : list A list of image catalogs containing source information for alignment. Returns ------- None """ tweakreg.absolute_align( imcats, self.abs_refcat, ref_wcs=ref_image.meta.wcs, ref_wcsinfo=ref_image.meta.wcsinfo, epoch=ref_image.meta.exposure.start_time.decimalyear, abs_minobj=self.abs_minobj, abs_fitgeometry=self.abs_fitgeometry, abs_nclip=self.abs_nclip, abs_sigma=self.abs_sigma, abs_searchrad=self.abs_searchrad, abs_use2dhist=self.abs_use2dhist, abs_separation=self.abs_separation, abs_tolerance=self.abs_tolerance, save_abs_catalog=False, clip_accum=True, timeout=self.vo_timeout, )
def _validate_catalog_columns(catalog) -> bool: """ Validate the presence of required columns in the catalog. This method checks if the specified axis column exists in the catalog. If the axis is not found, it looks for a corresponding psf column and renames it if present. If neither is found, it raises an error. Parameters ---------- catalog : Table The catalog to validate, which should contain source information. Returns ------- True if all the required columns are present, False otherwise. """ for axis in ["x", "y"]: if axis not in catalog.colnames: long_axis = f"{axis}_psf" if long_axis in catalog.colnames: catalog.rename_column(long_axis, axis) else: return False return True def _add_required_columns(catalog): """ Updates the input catalog with the required columns based on the standard output from SourceCatalogStep. The centroid coordinates are always present in the standard output from SourceCatalogStep. Parameters ---------- catalog : Table The catalog to validate, which should contain source information. Returns ------- None """ catalog["x"] = catalog["x_centroid"] catalog["y"] = catalog["y_centroid"] def _filter_catalog(catalog): """ Remove flagged sources from catalog for tweakreg purposes. This presently removes only sources whose central cores are flagged DO_NOT_USE. Parameters ---------- catalog : Table The catalog from which to filter flagged sources. Returns ------- The filtered catalog """ if "warning_flags" in catalog.dtype.names: bad = (catalog["warning_flags"] & dqflags.pixel.DO_NOT_USE) != 0 catalog = catalog[~bad] return catalog