Source code for rubin_scheduler.scheduler.basis_functions.basis_functions

# Public API of this module: the names exported by a star-import.
# The ``*Filter*`` entries are deprecated aliases of the matching
# ``*Band*`` classes (see the class definitions below).
__all__ = (
    "BaseBasisFunction",
    "HealpixLimitedBasisFunctionMixin",
    "ConstantBasisFunction",
    "SimpleArrayBasisFunction",
    "DelayStartBasisFunction",
    "AvoidFastRevisitsBasisFunction",
    "VisitRepeatBasisFunction",
    "M5DiffBasisFunction",
    "M5DiffAtHpixBasisFunction",
    "StrictBandBasisFunction",
    "StrictFilterBasisFunction",
    "BandChangeBasisFunction",
    "FilterChangeBasisFunction",
    "SlewtimeBasisFunction",
    "CadenceEnhanceBasisFunction",
    "CadenceEnhanceTrapezoidBasisFunction",
    "AzimuthBasisFunction",
    "AzModuloBasisFunction",
    "DecModuloBasisFunction",
    "MapModuloBasisFunction",
    "SeasonCoverageBasisFunction",
    "NObsPerYearBasisFunction",
    "CadenceInSeasonBasisFunction",
    "NearSunHighAirmassBasisFunction",
    "EclipticBasisFunction",
    "VisitGap",
    "NGoodSeeingBasisFunction",
    "AvoidDirectWind",
    "BalanceVisits",
    "RewardNObsSequence",
    "BandDistBasisFunction",
    "FilterDistBasisFunction",
    "RewardRisingBasisFunction",
    "send_unused_deprecation_warning",
)

import warnings

import healpy as hp
import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord

from rubin_scheduler.scheduler import features, utils
from rubin_scheduler.scheduler.utils import IntRounded, get_current_footprint
from rubin_scheduler.skybrightness_pre import dark_m5
from rubin_scheduler.utils import DEFAULT_NSIDE, SURVEY_START_MJD, _hpid2_ra_dec


def send_unused_deprecation_warning(name):
    """Emit a `FutureWarning` saying that a basis function looks unused.

    Parameters
    ----------
    name : `str`
        Name of the basis function, interpolated into the warning text.
    """
    text = "".join(
        (
            f"The basis function {name} is not in use by the current ",
            "baseline scheduler and may be deprecated shortly. ",
            "Please contact the rubin_scheduler maintainers if ",
            "this is in use elsewhere.",
        )
    )
    warnings.warn(text, category=FutureWarning)


class BaseBasisFunction:
    """Class that takes features and computes a reward function when called.

    Parameters
    ----------
    nside : `int`, optional
        HEALpix nside stored as ``self.nside``. If None,
        `utils.set_default_nside()` is used instead.
    bandname : `str`, optional
        Band name stored as ``self.bandname``.
    filtername : `str`, optional
        Deprecated alias for ``bandname``. When not None, a
        `FutureWarning` is emitted and its value replaces ``bandname``.
    **kwargs
        Accepted for signature compatibility; not used here.
    """

    def __init__(self, nside=DEFAULT_NSIDE, bandname=None, filtername=None, **kwargs):
        # Set if basis function needs to be recalculated if there is a new
        # observation
        self.update_on_newobs = True
        # Set if basis function needs to be recalculated if conditions
        # change
        self.update_on_mjd = True
        # Dict to hold all the features we want to track
        self.survey_features = {}
        # Keep track of the last time the basis function was called.
        # If mjd doesn't change, use cached value
        self.mjd_last = None
        self.value = 0
        # list the attributes to compare to check if basis functions
        # are equal.
        self.attrs_to_compare = []
        # Do we need to recalculate the basis function
        self.recalc = True
        # Basis functions don't technically all need an nside, but so
        # many do might as well set it here
        if nside is None:
            self.nside = utils.set_default_nside()
        else:
            self.nside = nside

        if filtername is not None:
            warnings.warn(
                "Use of `filtername` will be deprecated in favor of `bandname` at v4", FutureWarning
            )
            bandname = filtername
        # Save filtername as a backup in case someone tries to access it
        self.filtername = filtername
        self.bandname = bandname

    def add_observations_array(self, observations_array, observations_hpid):
        """Similar to add_observation, but for loading a whole array of
        observations at a time.

        Parameters
        ----------
        observations_array : `np.array`
            An array of completed observations (with columns like
            rubin_scheduler.scheduler.utils.ObservationArray).
            Should be sorted by MJD.
        observations_hpid : `np.array`
            Same as observations_array, but larger and with an
            additional column for HEALpix id. Each observation is
            listed multiple times, once for every HEALpix it overlaps.
        """
        for feature in self.survey_features.values():
            feature.add_observations_array(observations_array, observations_hpid)
        if self.update_on_newobs:
            self.recalc = True

    def add_observation(self, observation, indx=None):
        """Pass a single observation on to every tracked feature.

        Parameters
        ----------
        observation : `np.array`
            An array with information about the input observation
        indx : `np.array`
            The indices of the healpix map that the observation
            overlaps with
        """
        for feature in self.survey_features.values():
            feature.add_observation(observation, indx=indx)
        if self.update_on_newobs:
            self.recalc = True

    def check_feasibility(self, conditions):
        """If there is logic to decide if something is feasible
        (e.g., only if moon is down), it can be calculated here.

        Helps prevent full __call__ from being called more than needed.
        The base implementation is always feasible.
        """
        return True

    def _calc_value(self, conditions, **kwargs):
        """Compute the reward; subclasses override this.

        The base implementation resets the value to 0, records the
        current mjd and clears the ``recalc`` flag.
        """
        self.value = 0
        # Update the last time we had an mjd
        self.mjd_last = conditions.mjd + 0
        self.recalc = False
        return self.value

    def __eq__(self, other):
        # XXX--to work on if we need to make a registry of basis functions.
        # Returning NotImplemented makes Python fall back to identity
        # comparison instead of raising TypeError on ``a == b``.
        return NotImplemented

    def __ne__(self, other):
        return NotImplemented

    # Defining __eq__ would otherwise set __hash__ to None; equality is
    # identity based, so the identity hash stays consistent with it.
    __hash__ = object.__hash__

    def __call__(self, conditions, **kwargs):
        """Return the reward for the given conditions.

        Parameters
        ----------
        conditions : `rubin_scheduler.scheduler.features.conditions` object
            Object that has attributes for all the current conditions.

        Returns
        -------
        value : `np.array` or `float`
            A reward healpix map or a reward scalar; ``-np.inf`` when
            `check_feasibility` returns False.
        """
        # If we are not feasible, return -inf
        if not self.check_feasibility(conditions):
            return -np.inf
        # Evaluate at most once per call: subclasses whose _calc_value
        # does not update ``mjd_last``/``recalc`` would otherwise be
        # computed twice for the same conditions.
        needs_update = self.recalc or (self.update_on_mjd and conditions.mjd != self.mjd_last)
        if needs_update:
            self.value = self._calc_value(conditions, **kwargs)
        return self.value

    def label(self):
        """Create a label for this basis function.

        Returns
        -------
        label : `str`
            A string suitable for labeling the basis function in a plot
            or table.
        """
        label = self.__class__.__name__.replace("BasisFunction", "")

        if self.bandname is not None:
            label += f" {self.bandname}"

        label += f" @{id(self)}"
        return label
class HealpixLimitedBasisFunctionMixin:
    """A mixin to limit a basis function to a set of Healpix pixels.

    Parameters
    ----------
    hpid : `int` or array-like of `int`
        Index (or indices) used to select pixels out of the all-sky
        value returned by the wrapped basis function.
    *args, **kwargs
        Forwarded to the next class in the MRO.
    """

    def __init__(self, hpid, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hpid = hpid

    def check_feasibility(self, conditions):
        """Check the feasibility of the current set of conditions.

        Parameters
        ----------
        conditions : `rubin_scheduler.scheduler.features.Conditions`
            The conditions for which to test feasibility.

        Returns
        -------
        feasibility : `bool`
            True if the current set of conditions is feasible, False
            otherwise.
        """
        if not super().check_feasibility(conditions):
            return False

        stale = self.recalc or (self.update_on_mjd and conditions.mjd != self.mjd_last)
        if stale:
            value = self._calc_value(conditions)
        else:
            # ``value`` is an instance attribute, so it must be read
            # from ``self``; ``super().value`` raises AttributeError.
            value = self.value
        # Feasible when at least one selected pixel has a finite or
        # +inf reward.
        return np.nanmax(value) > -np.inf

    def _calc_value(self, conditions, all_sky=False, **kwargs):
        """Compute the wrapped value, restricted to ``self.hpid``.

        Parameters
        ----------
        conditions : object
            Passed through to the wrapped ``_calc_value``.
        all_sky : `bool`, optional
            If True, return the unrestricted value.
        """
        all_sky_value = super()._calc_value(conditions, **kwargs)
        if all_sky or np.isscalar(all_sky_value):
            return all_sky_value

        assert len(all_sky_value) == hp.nside2npix(self.nside)
        return all_sky_value[self.hpid]
class ConstantBasisFunction(BaseBasisFunction):
    """Basis function whose reward is always the constant 1."""

    def __call__(self, conditions, **kwargs):
        # Overrides the base __call__, so no feasibility check or
        # caching is involved.
        constant_reward = 1
        return constant_reward
class SimpleArrayBasisFunction(BaseBasisFunction):
    """Basis function that always returns a user-supplied value.

    Parameters
    ----------
    value : object
        The value handed back by `_calc_value` (stored as
        ``self.assigned_value``).
    *args, **kwargs
        Forwarded to `BaseBasisFunction`.
    """

    def __init__(self, value, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.assigned_value = value

    def _calc_value(self, conditions, **kwargs):
        stored = self.assigned_value
        self.value = stored
        return stored
class DelayStartBasisFunction(BaseBasisFunction):
    """Force things to not run before a given night.

    Parameters
    ----------
    nights_delay : `float`, optional
        Return False until conditions.night >= nights_delay.
    """

    def __init__(self, nights_delay=365.25 * 5):
        super().__init__()
        self.nights_delay = nights_delay

    def check_feasibility(self, conditions):
        """Return False while ``conditions.night`` is before the delay."""
        too_early = conditions.night < self.nights_delay
        return False if too_early else True
class BandDistBasisFunction(BaseBasisFunction):
    """Track band distribution, increase reward as fraction of
    observations in specified band drops.

    Parameters
    ----------
    bandname : `str`, optional
        The band whose share of all observations is tracked.
    """

    def __init__(self, bandname="r"):
        super().__init__(bandname=bandname)
        # One counter for every observation, one for the chosen band.
        self.survey_features = {
            "n_obs_count_all": features.NObsCount(bandname=None),
            "n_obs_count_in_filt": features.NObsCount(bandname=bandname),
        }

    def _calc_value(self, conditions, indx=None):
        n_all = self.survey_features["n_obs_count_all"].feature
        n_in_band = self.survey_features["n_obs_count_in_filt"].feature
        # The +1 keeps the ratio finite before any in-band observation.
        return n_all / (n_in_band + 1)
class FilterDistBasisFunction(BandDistBasisFunction):
    """Deprecated version of BandDistBasisFunction.

    Parameters
    ----------
    filtername : `str`, optional
        Passed to `BandDistBasisFunction` as ``bandname``.
    """

    def __init__(self, filtername="r"):
        warnings.warn(
            "FilterDistBasisFunction deprecated for BandDistBasisFunction",
            FutureWarning,
        )
        super().__init__(bandname=filtername)
class NObsPerYearBasisFunction(BaseBasisFunction):
    """Reward areas that have not been observed N-times in the last year

    Parameters
    ----------
    bandname : `str` ('r')
        The band to track
    nside : `int`
        The nside of the HEALpix maps used by the basis function.
    footprint : `np.array`
        Should be a HEALpix map. Values of 0 or np.nan will be ignored.
        If None, no pixels are masked out.
    n_obs : `int` (3)
        The number of observations to demand
    season : `float` (300)
        The amount of time to allow pass before marking a region as
        "behind" (days).
    season_start_hour : `float` (-4)
        When to start the season relative to RA 180 degrees away from
        the sun (hours)
    season_end_hour : `float` (2)
        When to consider a season ending, the RA relative to the
        sun + 180 degrees. (hours)
    night_max : float (365)
        Set value to zero after night_max is reached (days)
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(
        self,
        bandname="r",
        nside=DEFAULT_NSIDE,
        footprint=None,
        n_obs=3,
        season=300,
        season_start_hour=-4.0,
        season_end_hour=2.0,
        night_max=365,
        filtername=None,
    ):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        super().__init__(nside=nside, bandname=bandname)
        self.footprint = footprint
        self.n_obs = n_obs
        self.season = season
        self.season_start_hour = season_start_hour * np.pi / 12.0  # To radians
        self.season_end_hour = season_end_hour * np.pi / 12.0  # To radians

        self.survey_features["last_n_mjds"] = features.LastNObsTimes(
            nside=self.nside, bandname=bandname, n_obs=n_obs
        )
        self.result = np.zeros(hp.nside2npix(self.nside), dtype=float)
        if footprint is None:
            # The default used to crash in np.isnan(None); with no
            # footprint there is nothing to mask.
            self.out_footprint = (np.array([], dtype=int),)
        else:
            self.out_footprint = np.where((footprint == 0) | np.isnan(footprint))
        self.night_max = night_max

    def _calc_value(self, conditions, indx=None):
        if conditions.night > self.night_max:
            return 0

        result = self.result.copy()
        # Row 0 of the last_n_mjds feature is compared against `season`
        # -- presumably the oldest of the tracked visit times; verify.
        behind_pix = np.where((conditions.mjd - self.survey_features["last_n_mjds"].feature[0]) > self.season)
        result[behind_pix] = 1

        # let's ramp up the weight depending on how far into the
        # observing season the healpix is
        season_span = self.season_end_hour - self.season_start_hour
        mid_season_ra = (conditions.sun_ra + np.pi) % (2.0 * np.pi)
        # relative RA
        relative_ra = (conditions.ra - mid_season_ra) % (2.0 * np.pi)
        relative_ra = (self.season_end_hour - relative_ra) % (2.0 * np.pi)
        # Pixels outside the season window get zero weight.
        relative_ra[np.where(IntRounded(relative_ra) > IntRounded(season_span))] = 0

        weight = relative_ra / season_span
        result *= weight

        # mask off anything outside the footprint
        result[self.out_footprint] = 0

        return result
class NGoodSeeingBasisFunction(BaseBasisFunction):
    """Try to get N "good seeing" images each observing season.

    Parameters
    ----------
    bandname : `str`
        Bandpass in which to count images. Default r.
    nside : `int`
        The nside of the map for the basis function. Should match
        survey and scheduler nside. Default None uses `set_default_nside`.
    seeing_fwhm_max : `float`
        Value to consider as "good" threshold (arcsec).
        Default of 0.8 arcseconds.
    m5_penalty_max : `float`
        The maximum depth loss that is considered acceptable (magnitudes),
        compared to the dark-sky map in this band.
        Default 0.5 magnitudes.
    n_obs_desired : `int`
        Number of good seeing observations to collect per season.
        Default 3.
    mjd_start : `float`
        The starting MJD of the survey.
        Default None uses `rubin_scheduler.utils.SURVEY_START_MJD`.
    footprint : `np.array`, (N,)
        Only use area where footprint > 0. Should be a HEALpix map.
        Default None calls `get_current_footprint()`.
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(
        self,
        bandname="r",
        nside=DEFAULT_NSIDE,
        seeing_fwhm_max=0.8,
        m5_penalty_max=0.5,
        n_obs_desired=3,
        mjd_start=None,
        footprint=None,
        filtername=None,
    ):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        super().__init__(nside=nside, bandname=bandname)

        self.seeing_fwhm_max = seeing_fwhm_max
        self.m5_penalty_max = m5_penalty_max
        self.n_obs_desired = n_obs_desired
        self.mjd_start = SURVEY_START_MJD if mjd_start is None else mjd_start

        self.survey_features["N_good_seeing"] = features.NObservationsCurrentSeason(
            bandname=self.bandname,
            mjd_start=self.mjd_start,
            seeing_fwhm_max=self.seeing_fwhm_max,
            m5_penalty_max=self.m5_penalty_max,
            nside=self.nside,
        )

        # Fall back to the current survey footprint for this band.
        if footprint is None:
            footprints, _ = get_current_footprint(self.nside)
            footprint = footprints[self.bandname]
        self.footprint = footprint

        self.result = np.zeros(hp.nside2npix(self.nside))
        # Filled lazily on the first _calc_value call.
        self.dark_map = None

    def _calc_value(self, conditions, indx=None):
        if self.bandname is not None and self.dark_map is None:
            self.dark_map = dark_m5(
                conditions.dec, self.bandname, conditions.site.latitude_rad, fiducial_FWHMEff=0.7
            )

        # Always hand back an array (never a bare float).
        reward = self.result.copy()

        # Bring the per-season counter up to the current time.
        counter = self.survey_features["N_good_seeing"]
        counter.season_update(conditions=conditions)

        depth_loss = self.dark_map - conditions.m5_depth[self.bandname]
        wanted = (
            (depth_loss <= self.m5_penalty_max)
            & (conditions.fwhm_eff[self.bandname] <= self.seeing_fwhm_max)
            & (counter.feature < self.n_obs_desired)
            & (self.footprint > 0)
        )
        reward[np.where(wanted)[0]] = 1
        return reward
def az_rel_point(azs, point_az):
    """Return the unsigned azimuth separation from ``point_az``.

    Parameters
    ----------
    azs : `float` or `np.ndarray`
        Azimuth value(s) (radians).
    point_az : `float`
        Reference azimuth (radians).

    Returns
    -------
    separation : `float` or `np.ndarray`
        Separation folded into the range [0, pi].
    """
    full_circle = 2.0 * np.pi
    separation = (azs - point_az) % full_circle
    if not isinstance(azs, np.ndarray):
        # Scalar input: fold by hand.
        if separation > np.pi:
            separation = full_circle - separation
        return separation

    wrapped = separation > np.pi
    separation[wrapped] = full_circle - separation[wrapped]
    return separation
class EclipticBasisFunction(BaseBasisFunction):
    """Mark the area around the ecliptic

    Parameters
    ----------
    nside : `int`, optional
        The nside of the HEALpix map. None falls back to the default
        resolved by `BaseBasisFunction`.
    distance_to_eclip : `float`, optional
        Pixels with absolute ecliptic latitude below this value get a
        reward of 1 (degrees).
    """

    def __init__(self, nside=DEFAULT_NSIDE, distance_to_eclip=25.0):
        super().__init__(nside=nside)
        self.distance_to_eclip = np.radians(distance_to_eclip)
        # Use the resolved self.nside throughout so nside=None works
        # and the coordinates always match the map size.
        ra, dec = _hpid2_ra_dec(self.nside, np.arange(hp.nside2npix(self.nside)))
        self.result = np.zeros(ra.size)
        coord = SkyCoord(ra=ra * u.rad, dec=dec * u.rad)
        eclip_lat = coord.barycentrictrueecliptic.lat.radian
        good = np.where(np.abs(eclip_lat) < self.distance_to_eclip)
        self.result[good] += 1

    def __call__(self, conditions, indx=None):
        # The map is static, so it is returned without recomputation.
        return self.result
class CadenceInSeasonBasisFunction(BaseBasisFunction):
    """Drive observations at least every N days in a given area

    Parameters
    ----------
    drive_map : `np.ndarray`, (N,)
        A HEALpix map with values of 1 where the cadence should be driven.
    bandname : `str`
        The bands that can count.
    season_span : `float`
        How long to consider a spot "in_season" (hours).
    cadence : `float`
        How long to wait before activating the basis function (days).
    nside : `int`, optional
        The nside of the HEALpix maps used by the basis function.
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(
        self, drive_map, bandname="griz", season_span=2.5, cadence=2.5, nside=DEFAULT_NSIDE, filtername=None
    ):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        super().__init__(nside=nside, bandname=bandname)
        self.drive_map = drive_map
        self.season_span = season_span / 12.0 * np.pi  # To radians
        self.cadence = cadence
        # Use the resolved nside so the feature map matches self.result.
        self.survey_features["last_observed"] = features.LastObserved(nside=self.nside, bandname=bandname)
        self.result = np.zeros(hp.nside2npix(self.nside), dtype=float)

    def _calc_value(self, conditions, indx=None):
        result = self.result.copy()
        # The sun's RA is read as ``sun_ra``, matching the snake_case
        # Conditions attributes used elsewhere in this module.
        ra_mid_season = (conditions.sun_ra + np.pi) % (2.0 * np.pi)

        # Angular distance to mid-season, folded into [0, pi].
        angle_to_mid_season = np.abs(conditions.ra - ra_mid_season)
        over = np.where(IntRounded(angle_to_mid_season) > IntRounded(np.pi))
        angle_to_mid_season[over] = 2.0 * np.pi - angle_to_mid_season[over]

        days_lag = conditions.mjd - self.survey_features["last_observed"].feature

        active_pix = np.where(
            (IntRounded(days_lag) >= IntRounded(self.cadence))
            & (self.drive_map == 1)
            & (IntRounded(angle_to_mid_season) < IntRounded(self.season_span))
        )
        result[active_pix] = 1.0

        return result
class SeasonCoverageBasisFunction(BaseBasisFunction):
    """Basis function to encourage N observations per observing season.

    Parameters
    ----------
    bandname : `str`, optional
        Count observations in this band. Default 'r'.
    nside : `int`, optional
        Nside for the healpix map to use for the feature.
        This should match the nside of the survey and scheduler.
    footprint : `np.array` (N,), optional
        Healpix map of the footprint where one should demand coverage
        every season. Default None will call `get_current_footprint()`.
    n_per_season : `int`, optional
        The number of observations to attempt to gather every season.
        Default of 3 is suitable for first year template building.
    mjd_start : `float`, optional
        The mjd of the start of the survey (days).
        Default None uses `rubin_scheduler.utils.SURVEY_START_MJD`.
    season_frac_start : `float`
        Only start trying to gather observations after a season
        is fractionally this far along.
        Seasons start when the apparent position of sun is at the RA
        of the pixel (0) and finish when the sun returns again to this RA.
        The default of 0.5 means that the basis function will not start
        returning values until the RA reaches the peak of its season.
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(
        self,
        bandname="r",
        nside=DEFAULT_NSIDE,
        footprint=None,
        n_per_season=3,
        mjd_start=None,
        season_frac_start=0.5,
        filtername=None,
    ):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        send_unused_deprecation_warning("SeasonCoverageBasisFunction")
        super().__init__(nside=nside, bandname=bandname)

        if footprint is None:
            footprints, labels = get_current_footprint(self.nside)
            footprint = footprints[self.bandname]
        self.footprint = footprint

        # Calculate the RA values for each spot on the footprint.
        # The resolved self.nside is used everywhere below so that
        # nside=None works and all maps share the same size.
        ra, dec = _hpid2_ra_dec(self.nside, np.arange(hp.nside2npix(self.nside)))
        self.ra_deg = np.degrees(ra)

        self.n_per_season = n_per_season

        if mjd_start is None:
            mjd_start = SURVEY_START_MJD
        self.mjd_start = mjd_start
        self.season_frac_start = season_frac_start

        # Track how many observations have been taken at each RA/Dec
        # in the current observing season (for that point on the sky).
        self.survey_features["n_obs_season"] = features.NObservationsCurrentSeason(
            bandname=bandname, nside=self.nside, mjd_start=mjd_start
        )
        self.result = np.zeros(hp.nside2npix(self.nside), dtype=float)

    def _calc_value(self, conditions, indx=None):
        result = self.result.copy()
        # Update the feature to the current time
        self.survey_features["n_obs_season"].season_update(conditions=conditions)
        feature = self.survey_features["n_obs_season"].feature
        # Get the season from the conditions object.
        season = conditions.season
        # Evaluate where there are not yet enough observations and also
        # that season is far enough along to require more weight.
        not_enough = np.where(
            (self.footprint > 0)
            & (feature < self.n_per_season)
            & ((IntRounded(season - np.floor(season)) > IntRounded(self.season_frac_start)))
        )
        result[not_enough] = 1
        return result
class AvoidFastRevisitsBasisFunction(BaseBasisFunction):
    """Marks targets as unseen if they are in a specified time window
    in order to avoid fast revisits.

    Parameters
    ----------
    bandname : `str` or None
        The name of the band for this target map.
        Using None will match visits in any band.
    gap_min : `float`
        Minimum time for the gap (minutes).
    nside: `int` or None
        The healpix resolution.
    penalty_val : `float`
        The reward value to use for regions to penalize.
        Will be masked if set to np.nan (default).
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(self, bandname="r", nside=DEFAULT_NSIDE, gap_min=25.0, penalty_val=np.nan, filtername=None):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        # The base class resolves nside=None to the default; self.nside
        # must not be overwritten with the raw argument afterwards.
        super().__init__(nside=nside, bandname=bandname)

        self.penalty_val = penalty_val

        # Minutes -> days, wrapped for rounded comparisons.
        self.gap_min = IntRounded(gap_min / 60.0 / 24.0)

        self.survey_features = {
            "Last_observed": features.LastObserved(bandname=bandname, nside=self.nside, fill=0),
        }

    def _calc_value(self, conditions, indx=None):
        result = np.ones(hp.nside2npix(self.nside), dtype=float)
        diff = IntRounded(conditions.mjd - self.survey_features["Last_observed"].feature)
        # Pixels observed more recently than gap_min get the penalty.
        bad = np.where(diff < self.gap_min)[0]
        result[bad] = self.penalty_val
        return result
class NearSunHighAirmassBasisFunction(BaseBasisFunction):
    """Reward areas on the sky at high airmass, within 90 degrees
    azimuth of the Sun, such as suitable for the near-sun twilight
    microsurvey for near- or interior-to earth asteroids.

    Parameters
    ----------
    nside : `int`, optional
        Nside for the basis function. If None, uses `set_default_nside()`.
    max_airmass : `float`, optional
        The maximum airmass to try and observe (unitless).
    penalty : `float`, optional
        The value to fill in non-rewarded parts of the sky.
        Default np.nan, which serves to mask regions exceeding the
        airmass limit and more than 90 degrees azimuth toward the sun.
    """

    def __init__(self, nside=DEFAULT_NSIDE, max_airmass=2.5, penalty=np.nan):
        super().__init__(nside=nside)
        self.max_airmass = IntRounded(max_airmass)
        # Template map: every pixel starts at the penalty value.
        template = np.empty(hp.nside2npix(self.nside))
        template.fill(penalty)
        self.result = template

    def _calc_value(self, conditions, indx=None):
        reward = self.result.copy()
        airmass = conditions.airmass
        finite_idx = np.where(np.isfinite(airmass))[0]
        finite_airmass = airmass[finite_idx]
        selected = np.where(
            (finite_airmass >= 1.0)
            & (IntRounded(finite_airmass) < self.max_airmass)
            & (IntRounded(np.abs(conditions.az_to_sun[finite_idx])) < IntRounded(np.pi / 2.0))
        )
        # Reward scales with airmass, normalised by the airmass limit.
        reward[finite_idx[selected]] = finite_airmass[selected] / self.max_airmass.initial
        return reward
class VisitRepeatBasisFunction(BaseBasisFunction):
    """
    Basis function to reward re-visiting an area on the sky.
    Looking for Solar System objects.

    Parameters
    ----------
    gap_min : `float` (25.)
        Minimum time for the gap (minutes)
    gap_max : `float` (45.)
        Maximum time for a gap (minutes)
    bandname : `str` ('r')
        The band(s) to count with pairs
    nside : `int`, optional
        The nside of the HEALpix maps used by the basis function.
    npairs : `int` (1)
        The number of pairs of observations to attempt to gather
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(
        self, gap_min=25.0, gap_max=45.0, bandname="r", nside=DEFAULT_NSIDE, npairs=1, filtername=None
    ):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        super().__init__(nside=nside, bandname=bandname)

        minutes_per_day = 60.0 * 24.0
        self.gap_min = IntRounded(gap_min / 60.0 / 24.0)
        self.gap_max = IntRounded(gap_max / 60.0 / 24.0)
        del minutes_per_day  # conversion is spelled out above to keep rounding identical
        self.npairs = npairs

        self.survey_features = {
            # Track the number of pairs that have been taken in a night
            "Pair_in_night": features.PairInNight(
                bandname=bandname, gap_min=gap_min, gap_max=gap_max, nside=nside
            ),
            # When was it last observed
            # XXX--since this feature is also in Pair_in_night, I should
            # just access that one!
            "Last_observed": features.LastObserved(bandname=bandname, nside=nside),
        }

    def _calc_value(self, conditions, indx=None):
        reward = np.zeros(hp.nside2npix(self.nside), dtype=float)
        if indx is None:
            indx = np.arange(reward.size)

        elapsed = conditions.mjd - self.survey_features["Last_observed"].feature[indx]
        # Zero out NaNs so IntRounded can handle the array; the mask
        # excludes those entries again below.
        never_seen = np.isnan(elapsed)
        elapsed[never_seen] = 0.0
        elapsed_rounded = IntRounded(elapsed)

        in_window = np.where(
            (elapsed_rounded >= self.gap_min)
            & (elapsed_rounded <= self.gap_max)
            & (self.survey_features["Pair_in_night"].feature[indx] < self.npairs)
            & (~never_seen)
        )[0]
        reward[indx[in_window]] += 1.0
        return reward
class M5DiffBasisFunction(BaseBasisFunction):
    """Basis function based on the 5-sigma depth.
    Look up the best depth a healpixel achieves, and compute
    the limiting depth difference given current conditions

    Parameters
    ----------
    bandname : `str`, optional
        The band to consider for visits.
    fiducial_FWHMEff : `float`, optional
        The zenith seeing to assume for "good" conditions.
        While the dark sky depth map simply scales with this value,
        picking a reasonable fiducial_FWHMEff is important because
        this effects the overall value and scale of the reward
        from the basis function.
    nside : `int`, optional
        The nside for the basis function.
        Default None uses `set_default_nside()`.
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(self, bandname="r", fiducial_FWHMEff=0.7, nside=DEFAULT_NSIDE, filtername=None):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        super().__init__(nside=nside, bandname=bandname)
        self.bandname = bandname
        self.fiducial_FWHMEff = fiducial_FWHMEff
        # Dark-sky depth map, computed on first use.
        self.dark_map = None

    def _calc_value(self, conditions, indx=None):
        dark = self.dark_map
        if dark is None:
            dark = dark_m5(
                conditions.dec, self.bandname, conditions.site.latitude_rad, self.fiducial_FWHMEff
            )
            self.dark_map = dark

        # Current depth minus dark-sky depth.
        return conditions.m5_depth[self.bandname] - dark
class M5DiffAtHpixBasisFunction(HealpixLimitedBasisFunctionMixin, M5DiffBasisFunction):
    """M5DiffBasisFunction restricted to selected healpix pixels."""
class StrictBandBasisFunction(BaseBasisFunction):
    """Remove the bonus for staying in the same band
    if certain conditions are met.

    If the moon rises/sets or twilight starts/ends, it makes a lot of sense
    to consider a band change. This basis function rewards if it matches
    the current band, the moon rises or sets, twilight starts or stops,
    or there has been a large gap since the last observation.

    Parameters
    ----------
    time_lag : `float` (10.)
        If there is a gap between observations longer than this,
        let the band change (minutes)
    bandname : `str` ('r')
        The band this basis function is attached to.
    twi_change : `float` (-18.)
        The sun altitude to consider twilight starting/ending (degrees)
    note_free : `str` ('DD')
        No penalty for changing bands if the last observation note field
        includes `note_free` string.
        Useful for giving a free band change after deep drilling sequence
    """

    def __init__(self, time_lag=10.0, bandname="r", twi_change=-18.0, note_free="DD"):
        super().__init__(bandname=bandname)

        self.time_lag = time_lag / 60.0 / 24.0  # Convert to days
        self.twi_change = np.radians(twi_change)

        self.survey_features = {"Last_observation": features.LastObservation()}
        self.note_free = note_free

    def _calc_value(self, conditions, **kwargs):
        # Did the moon set or rise since last observation?
        moon_changed = conditions.moon_alt * self.survey_features["Last_observation"].feature["moonAlt"] < 0

        # Are we already in the band (or at start of night)?
        in_band = (conditions.current_band == self.bandname) | (conditions.current_band is None)

        # Has enough time past?
        elapsed = conditions.mjd - self.survey_features["Last_observation"].feature["mjd"]
        time_past = IntRounded(elapsed) > IntRounded(self.time_lag)

        # Did twilight start/end?
        sun_now = conditions.sun_alt - self.twi_change
        sun_before = self.survey_features["Last_observation"].feature["sunAlt"] - self.twi_change
        twi_changed = sun_now * sun_before < 0

        # Did we just finish a DD sequence
        was_dd = self.note_free in self.survey_features["Last_observation"].feature["scheduler_note"]

        # Is the band mounted?
        mounted = self.bandname in conditions.mounted_bands

        allowed = (moon_changed | in_band | time_past | twi_changed | was_dd) & mounted
        return 1.0 if allowed else 0.0
class StrictFilterBasisFunction(StrictBandBasisFunction):
    """Deprecated in favor of StrictBandBasisFunction.

    Takes the same arguments, with ``filtername`` in place of
    ``bandname``.
    """

    def __init__(self, time_lag=10.0, filtername="r", twi_change=-18.0, note_free="DD"):
        warnings.warn(
            "StrictFilterBasisFunction deprecated in favor of StrictBandBasisFunction",
            FutureWarning,
        )
        forwarded = dict(
            time_lag=time_lag,
            bandname=filtername,
            twi_change=twi_change,
            note_free=note_free,
        )
        super().__init__(**forwarded)
class BandChangeBasisFunction(BaseBasisFunction):
    """Reward staying in the current band.

    Parameters
    ----------
    bandname : `str`, optional
        The band this basis function is attached to.
    """

    def __init__(self, bandname="r"):
        super().__init__(bandname=bandname)

    def _calc_value(self, conditions, **kwargs):
        # Reward 1 when already in this band or when no band is set.
        same_band = (conditions.current_band == self.bandname) | (conditions.current_band is None)
        return 1.0 if same_band else 0.0
class FilterChangeBasisFunction(BandChangeBasisFunction):
    """Deprecated in favor of BandChangeBasisFunction.

    Parameters
    ----------
    filtername : `str`, optional
        Passed to `BandChangeBasisFunction` as ``bandname``.
    """

    def __init__(self, filtername="r"):
        warnings.warn(
            "FilterChangeBasisFunction deprecated in favor of BandChangeBasisFunction",
            FutureWarning,
        )
        super().__init__(bandname=filtername)
class SlewtimeBasisFunction(BaseBasisFunction):
    """Reward slews that take little time

    Parameters
    ----------
    max_time : `float`
        The estimated maximum slewtime (seconds).
        Used to normalize so the basis function spans ~ -1-0
        in reward units. Default 135 seconds corresponds to just
        slightly less than a band change.
    bandname : `str`, optional
        The band to check for pre-post slewtime estimates.
        If a slew includes a band change, other basis functions
        will decide on the reward, so the result here can be 0.
    nside : `int`, optional
        Nside for the basis function.
        Default None will use `set_default_nside()`.
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(self, max_time=135.0, bandname="r", nside=DEFAULT_NSIDE, filtername=None):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        # The base class stores bandname and resolves nside=None to the
        # default; re-assigning self.nside = nside here would undo that.
        super().__init__(nside=nside, bandname=bandname)

        self.maxtime = max_time

    def _calc_value(self, conditions, indx=None):
        slewtime = conditions.slewtime
        is_map = np.size(slewtime) > 1

        # If we are in a different band, the BandChangeBasisFunction
        # will take it. But we can still use the MASK returned by the
        # slewtime map to remove inaccessible parts of the sky.
        if conditions.current_band != self.bandname:
            if is_map:
                return np.where(np.isfinite(slewtime), 0, np.nan)
            return 0

        # Need to make sure smaller slewtime is larger reward.
        if is_map:
            # Slewtime map can contain nans and/or infs - mask these
            # with nans
            return np.where(np.isfinite(slewtime), -slewtime / self.maxtime, np.nan)
        return -slewtime / self.maxtime
class CadenceEnhanceBasisFunction(BaseBasisFunction):
    """Drive a certain cadence

    Parameters
    ----------
    bandname : `str` ('gri')
        The band(s) that should be grouped together
    nside : `int`, optional
        The nside of the HEALpix maps used by the basis function.
    supress_window : sequence of two `float`
        The start and stop window for when observations should
        be repressed (days)
    supress_val : `float`
        Reward assigned to pixels inside the suppress window.
    enhance_window : sequence of two `float`
        The start and stop window for when observations should
        be enhanced (days)
    enhance_val : `float`
        Reward assigned to pixels inside the enhance window.
    apply_area : healpix map
        The area over which to try and drive the cadence.
        Good values as 1, no cadence drive 0.
        Probably works as a bool array too.
    filtername : `str`, optional
        Deprecated alias for ``bandname``.
    """

    def __init__(
        self,
        bandname="gri",
        nside=DEFAULT_NSIDE,
        supress_window=(0, 1.8),
        supress_val=-0.5,
        enhance_window=(2.1, 3.2),
        enhance_val=1.0,
        apply_area=None,
        filtername=None,
    ):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        super().__init__(nside=nside, bandname=bandname)

        self.supress_window = np.sort(supress_window)
        self.supress_val = supress_val
        self.enhance_window = np.sort(enhance_window)
        self.enhance_val = enhance_val

        # The feature map must share the nside of ``self.empty``;
        # otherwise the pixel indices below address a differently sized
        # map whenever nside is not the feature's default.
        self.survey_features = {
            "last_observed": features.LastObserved(bandname=bandname, nside=self.nside),
        }

        self.empty = np.zeros(hp.nside2npix(self.nside), dtype=float)
        # No map, try to drive the whole area
        if apply_area is None:
            self.apply_indx = np.arange(self.empty.size)
        else:
            self.apply_indx = np.where(apply_area != 0)[0]

    def _calc_value(self, conditions, indx=None):
        # copy an empty array
        result = self.empty.copy()
        if indx is not None:
            ind = np.intersect1d(indx, self.apply_indx)
        else:
            ind = self.apply_indx
        if np.size(ind) == 0:
            # Nothing to drive: return a scalar zero.
            return 0

        mjd_diff = conditions.mjd - self.survey_features["last_observed"].feature[ind]
        to_supress = np.where(
            (IntRounded(mjd_diff) > IntRounded(self.supress_window[0]))
            & (IntRounded(mjd_diff) < IntRounded(self.supress_window[1]))
        )
        result[ind[to_supress]] = self.supress_val
        to_enhance = np.where(
            (IntRounded(mjd_diff) > IntRounded(self.enhance_window[0]))
            & (IntRounded(mjd_diff) < IntRounded(self.enhance_window[1]))
        )
        result[ind[to_enhance]] = self.enhance_val
        return result
# https://docs.astropy.org/en/stable/_modules/astropy/modeling
# functional_models.html#Trapezoid1D
def trapezoid(x, amplitude, x_0, width, slope):
    """One dimensional Trapezoid model function.

    Parameters
    ----------
    x : array-like
        Positions at which to evaluate the trapezoid.
    amplitude : `float`
        Height of the flat top.
    x_0 : `float`
        Center of the flat top.
    width : `float`
        Width of the flat top.
    slope : `float`
        Slope of the rising and falling edges. Must be non-zero.

    Returns
    -------
    result : `np.ndarray`
        Trapezoid values, as floats. Positions outside the trapezoid
        are 0; non-finite inputs (NaN, inf) yield NaN.
    """
    # Coerce to float so integer input does not produce an integer
    # result array that would silently truncate the sloped values.
    x = np.asanyarray(x, dtype=float)

    # Compute the four points where the trapezoid changes slope
    # x1 <= x2 <= x3 <= x4
    x2 = x_0 - width / 2.0
    x3 = x_0 + width / 2.0
    x1 = x2 - amplitude / slope
    x4 = x3 + amplitude / slope

    # x * 0 (rather than zeros) deliberately keeps NaN where x is
    # not finite.
    result = x * 0

    # Compute model values in pieces between the change points
    range_a = np.logical_and(x >= x1, x < x2)
    range_b = np.logical_and(x >= x2, x < x3)
    range_c = np.logical_and(x >= x3, x < x4)

    result[range_a] = slope * (x[range_a] - x1)
    result[range_b] = amplitude
    result[range_c] = slope * (x4 - x[range_c])

    return result
class CadenceEnhanceTrapezoidBasisFunction(BaseBasisFunction):
    """Drive a certain cadence, like CadenceEnhanceBasisFunction
    but with smooth transitions

    The reward as a function of time since a pixel was last observed
    is an "enhance" trapezoid minus a "delay" trapezoid.

    Parameters
    ----------
    bandname : `str` ('gri')
        The band(s) that should be grouped together
    nside : `int`, optional
        Nside for the healpix maps used by the basis function.
    delay_width, delay_slope, delay_peak, delay_amp : `float`
        Width, slope, center and amplitude of the trapezoid that is
        subtracted from the reward.
    enhance_width, enhance_slope, enhance_peak, enhance_amp : `float`
        Width, slope, center and amplitude of the trapezoid that is
        added to the reward.
    apply_area : healpix map, optional
        Pixels where this map is non-zero are rewarded. Default None
        uses the whole sky.
    season_limit : `float`, optional
        Pixels whose RA is further than this from the RA opposite the
        sun get zero reward. Converted to radians as
        ``season_limit / 12 * pi`` (presumably hours -- confirm units).
        Default None applies no season limit.
    filtername : `str`, optional
        Deprecated alias for `bandname`; emits a FutureWarning when used.
    """

    def __init__(
        self,
        bandname="gri",
        nside=DEFAULT_NSIDE,
        delay_width=2,
        delay_slope=2.0,
        delay_peak=0,
        delay_amp=0.5,
        enhance_width=3.0,
        enhance_slope=2.0,
        enhance_peak=4.0,
        enhance_amp=1.0,
        apply_area=None,
        season_limit=None,
        filtername=None,
    ):
        if filtername is not None:
            warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
            bandname = filtername
        super(CadenceEnhanceTrapezoidBasisFunction, self).__init__(nside=nside, bandname=bandname)

        self.delay_width = delay_width
        self.delay_slope = delay_slope
        self.delay_peak = delay_peak
        self.delay_amp = delay_amp
        self.enhance_width = enhance_width
        self.enhance_slope = enhance_slope
        self.enhance_peak = enhance_peak
        self.enhance_amp = enhance_amp

        # None means "no season limit"; only convert real values,
        # otherwise the default argument raises a TypeError.
        if season_limit is None:
            self.season_limit = None
        else:
            self.season_limit = season_limit / 12 * np.pi  # To radians

        self.survey_features = {}
        self.survey_features["last_observed"] = features.LastObserved(bandname=bandname)

        self.empty = np.zeros(hp.nside2npix(self.nside), dtype=float)
        # No map, try to drive the whole area
        if apply_area is None:
            self.apply_indx = np.arange(self.empty.size)
        else:
            self.apply_indx = np.where(apply_area != 0)[0]

    def suppress_enhance(self, x):
        """Return the enhance trapezoid minus the delay trapezoid at x."""
        result = x * 0
        result -= trapezoid(x, self.delay_amp, self.delay_peak, self.delay_width, self.delay_slope)
        result += trapezoid(
            x,
            self.enhance_amp,
            self.enhance_peak,
            self.enhance_width,
            self.enhance_slope,
        )
        return result

    def season_len(self, conditions):
        """Return the angular RA distance (radians, 0 to pi) of each
        pixel from the RA opposite the sun."""
        ra_mid_season = (conditions.sunRA + np.pi) % (2.0 * np.pi)
        angle_to_mid_season = np.abs(conditions.ra - ra_mid_season)
        # Fold distances larger than pi back onto the shorter arc.
        over = np.where(IntRounded(angle_to_mid_season) > IntRounded(np.pi))
        angle_to_mid_season[over] = 2.0 * np.pi - angle_to_mid_season[over]

        return angle_to_mid_season

    def _calc_value(self, conditions, indx=None):
        """Return the cadence reward map, or 0 if no pixel is selected."""
        if indx is not None:
            ind = np.intersect1d(indx, self.apply_indx)
        else:
            ind = self.apply_indx

        # Nothing to reward: return the scalar before any map indexing,
        # which would fail on an int.
        if np.size(ind) == 0:
            return 0

        # copy an empty array
        result = self.empty.copy()
        # Time since each selected pixel was last observed.
        mjd_diff = conditions.mjd - self.survey_features["last_observed"].feature[ind]
        result[ind] += self.suppress_enhance(mjd_diff)

        if self.season_limit is not None:
            radians_to_midseason = self.season_len(conditions)
            outside_season = np.where(radians_to_midseason > self.season_limit)
            result[outside_season] = 0

        return result
class AzimuthBasisFunction(BaseBasisFunction):
    """Basis function built from the azimuth separation to the telescope.

    Described as rewarding staying in the same azimuth range, and as
    possibly better than using slewtime, especially when selecting a
    large area of sky.

    Parameters
    ----------
    nside : `int`, optional
        Nside for the basis function.
    """

    def __init__(self, nside=DEFAULT_NSIDE):
        send_unused_deprecation_warning("AzimuthBasisFunction")
        super().__init__(nside=nside)

    def _calc_value(self, conditions, indx=None):
        """Return the azimuth separation from the telescope, scaled by pi."""
        two_pi = 2.0 * np.pi
        # Offset from the current telescope azimuth, wrapped to 0-2pi.
        separation = (conditions.az - conditions.tel_az) % two_pi
        # Separations beyond pi are shorter going the other way around.
        long_way = np.where(separation > np.pi)
        separation[long_way] = two_pi - separation[long_way]
        # Normalize so the value spans 0 to 1.
        return separation / np.pi
class AzModuloBasisFunction(BaseBasisFunction):
    """Try to replicate the Rothchild et al cadence forcing by only
    observing on limited az ranges per night.

    Parameters
    ----------
    nside : `int`, optional
        Nside for the basis function.
    az_limits : `list` of `float` pairs (None)
        The azimuth limits (degrees) to use. One pair is selected per
        night, cycling through the list. A pair whose first value is
        larger than its second is treated as wrapping through 0.
        Default None uses three 100-degree wide ranges centered on
        0, 90 and 180 degrees.
    out_of_bounds_val : `float`
        Value assigned to pixels outside the selected azimuth range.
        Pixels inside the range have value 1.
    """

    def __init__(self, nside=DEFAULT_NSIDE, az_limits=None, out_of_bounds_val=-1.0):
        super(AzModuloBasisFunction, self).__init__(nside=nside)
        self.result = np.ones(hp.nside2npix(self.nside))
        if az_limits is None:
            spread = 100.0 / 2.0
            self.az_limits = np.radians(
                [
                    [360 - spread, spread],
                    [90.0 - spread, 90.0 + spread],
                    [180.0 - spread, 180.0 + spread],
                ]
            )
        else:
            self.az_limits = np.radians(az_limits)
        self.mod_val = len(self.az_limits)
        self.out_of_bounds_val = out_of_bounds_val

    def _calc_value(self, conditions, indx=None):
        """Return a map of 1 inside tonight's azimuth range and
        `out_of_bounds_val` outside of it."""
        result = self.result.copy()
        az_lim = self.az_limits[np.max(conditions.night) % self.mod_val]

        below_start = IntRounded(conditions.az) < IntRounded(az_lim[0])
        above_stop = IntRounded(conditions.az) > IntRounded(az_lim[1])

        if az_lim[0] < az_lim[1]:
            # Ordinary range: outside means below the start OR above
            # the stop.
            out_pix = np.where(below_start | above_stop)
        else:
            # Range wraps through 0 (e.g. 310 -> 50 deg): outside is the
            # gap between stop and start, so both conditions must hold.
            # Using OR here would flag every pixel as out of bounds.
            out_pix = np.where(below_start & above_stop)
        result[out_pix] = self.out_of_bounds_val
        return result
class DecModuloBasisFunction(BaseBasisFunction):
    """Emphasize dec bands on a nightly varying basis

    Parameters
    ----------
    nside : `int`, optional
        Nside for the basis function. Passing None lets the base class
        fall back to `set_default_nside()`.
    dec_limits : `list` of `float` pairs (None)
        The declination limits (degrees) to use. One pair is selected
        per night, cycling through the list. Each band includes its
        lower limit and excludes its upper limit.
    out_of_bounds_val : `float`
        Stored on the instance.
        NOTE(review): this value is not applied to the maps; pixels
        outside the selected band are 0. Confirm whether that is
        intended before relying on this parameter.
    """

    def __init__(self, nside=DEFAULT_NSIDE, dec_limits=None, out_of_bounds_val=-1.0):
        super(DecModuloBasisFunction, self).__init__(nside=nside)

        # Use the nside resolved by the base class so nside=None works.
        npix = hp.nside2npix(self.nside)
        hpids = np.arange(npix)
        ra, dec = _hpid2_ra_dec(self.nside, hpids)

        self.results = []

        if dec_limits is None:
            self.dec_limits = np.radians([[-90.0, -32.8], [-32.8, -12.0], [-12.0, 35.0]])
        else:
            self.dec_limits = np.radians(dec_limits)
        self.mod_val = len(self.dec_limits)
        self.out_of_bounds_val = out_of_bounds_val

        # Pre-compute one map per declination band: 1 inside, 0 outside.
        for limits in self.dec_limits:
            good = np.where((dec >= limits[0]) & (dec < limits[1]))[0]
            tmp = np.zeros(npix)
            tmp[good] = 1
            self.results.append(tmp)

    def _calc_value(self, conditions, indx=None):
        """Return the pre-computed map for tonight's declination band."""
        night_index = np.max(conditions.night % self.mod_val)
        result = self.results[night_index]

        return result
class MapModuloBasisFunction(BaseBasisFunction):
    """Cycle through a set of input maps, one per night.

    Similar to Dec_modulo, but the masks are supplied by the caller.

    Parameters
    ----------
    inmaps : `list` of hp arrays
        The maps to cycle through. The nside of the basis function is
        derived from the size of the first map.
    """

    def __init__(self, inmaps):
        npix = np.size(inmaps[0])
        super().__init__(nside=hp.npix2nside(npix))
        self.maps = inmaps
        self.mod_val = len(inmaps)

    def _calc_value(self, conditions, indx=None):
        """Return the input map selected by the night modulo the
        number of maps."""
        map_index = np.max(conditions.night % self.mod_val)
        return self.maps[map_index]
class VisitGap(BaseBasisFunction):
    """Basis function to create a visit gap based on the survey note field.

    Parameters
    ----------
    note : `str`
        Value of the observation "scheduler_note" field to be masked.
    band_names : list [str], optional
        List of band names that will be considered when evaluating
        if the gap has passed.
    gap_min : float (optional)
        Time gap (default=25, in minutes).
    penalty_val : float or np.nan
        Value of the penalty to apply (default is np.nan).
    filter_names : list [str], optional
        Deprecated alias for `band_names`; emits a FutureWarning
        when used.

    Notes
    -----
    When a list of bands is provided, all bands must be observed before
    the gap requirement will be activated, and once activated, only
    observations in these bands will be evaluated in context of whether
    the last observation was at least gap in the past.
    """

    def __init__(self, note, band_names=None, gap_min=25.0, penalty_val=np.nan, filter_names=None):
        if filter_names is not None:
            warnings.warn("filter_names deprecated in favor of band_names", FutureWarning)
            band_names = filter_names
        super().__init__()
        self.penalty_val = penalty_val

        # Convert the gap from minutes to days, to compare against MJD.
        self.gap = gap_min / 60.0 / 24.0
        self.band_names = band_names

        self.survey_features = dict()
        if self.band_names is not None:
            for bandname in self.band_names:
                # Use the same `scheduler_note` keyword as the
                # band-independent feature below.
                self.survey_features[f"LastObservationMjd::{bandname}"] = features.LastObservationMjd(
                    scheduler_note=note, bandname=bandname
                )
        else:
            self.survey_features["LastObservationMjd"] = features.LastObservationMjd(scheduler_note=note)

    def check_feasibility(self, conditions):
        """Return True if the gap has passed for every tracked feature.

        Also returns True if any tracked feature has no recorded
        observation yet.
        """
        notes_last_observed = [last_observed.feature for last_observed in self.survey_features.values()]
        if any(last_observed is None for last_observed in notes_last_observed):
            return True

        return all(conditions.mjd - last_observed > self.gap for last_observed in notes_last_observed)

    def _calc_value(self, conditions, indx=None):
        """Return 1.0 when feasible, otherwise `penalty_val`."""
        return 1.0 if self.check_feasibility(conditions) else self.penalty_val
class AvoidDirectWind(BaseBasisFunction):
    """Mask the sky where the wind pressure exceeds `wind_speed_maximum`.

    Parameters
    ----------
    wind_speed_maximum : `float`, optional
        Wind speed to mark regions as unobservable (in m/s).
    nside : `int`, optional
        The nside for the basis function. Passing None lets the base
        class fall back to `set_default_nside()`.
    """

    def __init__(self, wind_speed_maximum=20.0, nside=DEFAULT_NSIDE):
        super().__init__(nside=nside)

        self.wind_speed_maximum = wind_speed_maximum

    def _calc_value(self, conditions, indx=None):
        """Return a map of minus the squared wind pressure, with NaN
        where the pressure exceeds the maximum.

        An all-zero map is returned when the wind speed or the wind
        direction is not available.
        """
        reward_map = np.zeros(hp.nside2npix(self.nside))

        wind_speed = conditions.wind_speed
        wind_direction = conditions.wind_direction
        have_wind = wind_speed is not None and wind_direction is not None

        if have_wind:
            # Component of the wind along each pixel's azimuth.
            wind_pressure = wind_speed * np.cos(conditions.az - wind_direction)
            reward_map -= wind_pressure**2.0
            too_windy = wind_pressure > self.wind_speed_maximum
            reward_map[too_windy] = np.nan

        return reward_map
class BalanceVisits(BaseBasisFunction):
    """Balance visits across multiple surveys.

    Parameters
    ----------
    nobs_reference : `int`
        Expected number of observations across all interested surveys.
    note_survey : `str`
        Note value for the current survey.
    note_interest : `str`
        Substring with the name of interested surveys to be accounted.
    nside : `int`
        Healpix map resolution.

    Notes
    -----
    This basis function is designed to balance the reward of a group of
    surveys, such that the group get a reward boost based on the required
    collective number of observations.

    For example, if you have 3 surveys (e.g. SURVEY_A_REGION_1,
    SURVEY_A_REGION_2, SURVEY_A_REGION_3), when one of them is observed once
    (SURVEY_A_REGION_1) they all get a small reward boost proportional to the
    collective number of observations (`nobs_reference`). Further observations
    of SURVEY_A_REGION_1 would now cause the other surveys to gain a reward
    boost in relative to it.
    """

    def __init__(self, nobs_reference, note_survey, note_interest, nside=DEFAULT_NSIDE):
        super().__init__(nside=nside)

        self.nobs_reference = nobs_reference

        self.survey_features = {
            "n_obs_survey": features.NObsCount(scheduler_note=note_survey),
            "n_obs_survey_interest": features.NObsCount(scheduler_note=note_interest),
        }

    def _calc_value(self, conditions, indx=None):
        """Return the group boost divided by this survey's own count."""
        n_interest = self.survey_features["n_obs_survey_interest"].feature
        n_survey = self.survey_features["n_obs_survey"].feature

        # Number of completed reference-sized batches across the group,
        # plus one.
        boost = 1 + np.floor(n_interest / self.nobs_reference)
        # Guard against dividing by zero before the survey is observed.
        divisor = n_survey if n_survey > 0 else 1
        return boost / divisor
class RewardNObsSequence(BaseBasisFunction):
    """Reward taking a sequence of observations.

    Parameters
    ----------
    n_obs_survey : `int`
        Number of observations to reward.
    note_survey : `str`
        The value of the observation note, to take into account.
    nside : `int`, optional
        Healpix map resolution (ignored).

    Notes
    -----
    This basis function is useful when a survey is composed of more than
    one observation (e.g. in different bands) and one wants to make sure
    they are all taken together.
    """

    def __init__(self, n_obs_survey, note_survey, nside=DEFAULT_NSIDE):
        super().__init__(nside=nside)

        self.n_obs_survey = n_obs_survey

        counter = features.NObsCount(scheduler_note=note_survey)
        self.survey_features = {"n_obs_survey": counter}

    def _calc_value(self, conditions, indx=None):
        """Return the observation count modulo the sequence length."""
        n_taken = self.survey_features["n_obs_survey"].feature
        return n_taken % self.n_obs_survey
class RewardRisingBasisFunction(BaseBasisFunction):
    """Reward parts of the sky that are rising.
    Optionally, mask out parts of the sky that are not rising.

    This produces a reward that increases
    as the field rises toward zenith, then abruptly falls as the field
    passes zenith.
    Negative hour angles (or hour angles > 180 degrees) indicate a rising
    point on the sky.

    Parameters
    ----------
    slope : `float`
        Sets the 'slope' of how fast the basis function value changes
        with hour angle.
    penalty_val : `float` or `np.nan`, optional
        Sets the value for the part of the sky which is not rising
        (hour angle between 0 and 180). Using a value of np.nan will
        mask this region of sky, a value of 0 will just make this
        non-rewarding.
    nside : `int` or None, optional
        Nside for the healpix map, default of None uses scheduler default.
    """

    def __init__(self, slope=0.1, penalty_val=0, nside=DEFAULT_NSIDE):
        super().__init__(nside=nside)
        self.slope = slope
        self.penalty_val = penalty_val

    # Probably not needed
    def check_feasibility(self, conditions):
        """Always feasible."""
        return True

    def _calc_value(self, conditions, indx=None):
        """Return slope * hour angle, with `penalty_val` wherever the
        hour angle is below pi."""
        # HA should be available in the conditions object
        hour_angle = conditions.HA
        value = self.slope * hour_angle
        not_rising = np.where(hour_angle < np.pi)
        value[not_rising] = self.penalty_val
        return value