# Source code for quasarnp.io

"""A module containing facilities for loading data and QuasarNP models.

Methods are provided for separately loading a QuasarNP model from a weights
file as well as loading DESI data, both by exposure number and by directory.
A method is also provided to load a DESI coadd file. Two legacy methods are
provided to interoperate with SDSS data: one to load a truth table and one to
load an SDSS data file. A further helper reads plate/MJD/fiber metadata from a
BOSS spAll file.
"""
import json
from pathlib import Path

import fitsio
import h5py
import numpy as np

from .model import QuasarNP
from .utils import rebin, renormalize, nbins, nbins_linear, wave, linear_wave


def _bn_layer_name(name, inc):
    """Return ``name`` rewritten so batch-normalization layers are 1-indexed.

    Parameters
    ----------
    name : str
        Layer name as stored in the weights file.
    inc : bool
        True if the file uses 0-indexed batch-normalization names, i.e. it
        contains a layer called exactly ``batch_normalization``.

    Returns
    -------
    str
        ``batch_normalization`` becomes ``batch_normalization_1``; when
        ``inc`` is True, ``batch_normalization_N`` becomes
        ``batch_normalization_{N+1}``. Any other name is returned unchanged.
    """
    if "batch_normalization" not in name:
        return name
    if name == "batch_normalization":
        return "batch_normalization_1"
    if inc:
        # Increment the numeric suffix as a whole number so multi-digit
        # suffixes (e.g. _19 -> _20) are handled, not just the last digit.
        prefix, _, index = name.rpartition("_")
        return f"{prefix}_{int(index) + 1}"
    return name


def load_file(filename):
    """Load a weights file as a dictionary.

    Parameters
    ----------
    filename : str
        The name of the weights file. It is opened with h5py and must contain
        a ``model_weights`` group and a JSON ``model_config`` attribute.

    Returns
    -------
    result : dict
        Dictionary that maps layer names to layer weights. Each value maps a
        weight name (with the trailing ``:0`` removed) to its array. Batch
        normalization layers are renamed to be 1-indexed
        (``batch_normalization_1``, ``batch_normalization_2``, ...). An empty
        ``lambda`` group is recorded as ``result["lambda"] = []``.
    config_dict : dict
        Dictionary mapping each ``conv*`` layer name to its ``padding`` and
        ``strides`` configuration options.
    w_grid : numpy.ndarray
        Wavelength grid used to train this network. Falls back to the
        logarithmic ``wave`` grid when the file has no ``model_grid`` dataset.
    """
    result = {}
    with h5py.File(filename, "r") as f:
        m_weights = f['model_weights']
        m_config = json.loads(f.attrs["model_config"])
        try:
            w_grid = f["model_grid"][:]
        except KeyError:
            print("Model grid not found in file, defaulting to logarithmic")
            w_grid = wave

        # Some versions of TF/Keras are 1 indexed and so bn layers start
        # at batch_normalization_1. Some versions are 0 indexed and start at
        # batch_normalization. The former is easier to account for in the
        # model object since it is of consistent form, so names are
        # rewritten to match that syntax.
        inc = "batch_normalization" in m_weights

        for k1, v1 in m_weights.items():
            # Groups with no weights are skipped; a "lambda" group is still
            # recorded so load_model can detect it.
            if len(v1) == 0:
                if k1 == "lambda":
                    result[k1] = []
                continue
            a = v1[k1]
            data_dict = {}
            for k2, v2 in a.items():
                # This :-2 strips off the :0 at the end of weights names.
                data_dict[k2[:-2]] = v2[()]
            result[_bn_layer_name(k1, inc)] = data_dict

    layers = m_config["config"]["layers"]
    names = [layer["config"]["name"] for layer in layers]
    conv_layers = [k for k in result if k.startswith("conv")]

    config_dict = {}
    for l in conv_layers:
        layer_config = layers[names.index(l)]["config"]
        config_dict[l] = {k: layer_config[k] for k in ("padding", "strides")}

    return result, config_dict, w_grid
def load_model(filename):
    """Build a callable QuasarNP model from a weights file.

    Parameters
    ----------
    filename : str
        Path to the weights file; it is parsed by ``load_file``.

    Returns
    -------
    QuasarNP
        Callable QuasarNP model initialised with the weights in `filename`.
        The number of layers is the number of ``conv*`` entries in the file.
    w_grid : numpy.ndarray
        Wavelength grid used to train this network.
    """
    weights, config, w_grid = load_file(filename)

    conv_count = sum(1 for layer_name in weights if layer_name.startswith("conv"))
    model_kwargs = {"nlayers": conv_count, "config_dict": config}

    # A "lambda" entry in the weights switches on rescale=True; otherwise
    # QuasarNP's own default for `rescale` applies. (Presumably this mirrors
    # a Keras Lambda rescaling layer — confirm against QuasarNP.)
    if "lambda" in weights:
        model_kwargs["rescale"] = True

    return QuasarNP(weights, **model_kwargs), w_grid
def read_truth(fi):
    """Read a list of truth files and return a dictionary of truth values.

    This is a legacy function ported from QuasarNet, and is designed to load
    SDSS data files to generate a truth table.

    Parameters
    ----------
    fi : list of str
        List of file names of truth files. HDU 1 of each file must contain a
        ``THING_ID`` column plus the columns listed in ``cols`` below.

    Returns
    -------
    dict
        Dictionary that maps `thing_id` to a metadata object whose attributes
        are the lower-cased column names (``z_vi``, ``plate``, ``mjd``,
        ``fiberid``, ``class_person``, ``z_conf_person``, ``bal_flag_vi``,
        ``bi_civ``). Later files overwrite earlier entries with the same id.
    """
    class metadata:
        """Plain attribute container; one instance per truth-table row."""

    cols = ['Z_VI', 'PLATE', 'MJD', 'FIBERID',
            'CLASS_PERSON', 'Z_CONF_PERSON', 'BAL_FLAG_VI', 'BI_CIV']

    truth = {}
    for f in fi:
        # The context manager closes the file even if a column read raises.
        with fitsio.FITS(f) as h:
            tids = h[1]['THING_ID'][:]
            cols_dict = {c.lower(): h[1][c][:] for c in cols}

        for i, t in enumerate(tids):
            m = metadata()
            for c, values in cols_dict.items():
                setattr(m, c, values[i])
            truth[t] = m

    return truth
def read_data(fi, truth=None, z_lim=2.1, return_pmf=False, nspec=None):
    """Read data from input file.

    This is a legacy function ported from QuasarNet, and is designed to load
    SDSS data files.

    Returns a tuple containing (tids, X, Y, z, bal) if `return_pmf` is
    `False`, otherwise returns a tuple containing
    (tids, X, Y, z, bal, plate, mjd, fid). If `truth` is None no labels are
    built, and only (tids, X), or (tids, X, plate, mjd, fid) when
    `return_pmf` is True, is returned.

    Parameters
    ----------
    fi : list of str
        List of data files to load.
    truth : dict, optional
        Dictionary that maps `thing_id` to truth metadata (as returned by
        `read_truth`). Spectra whose id is not a key are dropped.
    z_lim : float, optional
        Redshift to use when applying a z-cut. Defaults to 2.1.
    return_pmf : bool, optional
        Whether or not to return the `plate`, `mjd` and `fiberid`.
        Defaults to False.
    nspec : int, optional
        Maximum number of spectra to read from each file. Defaults to None
        (all spectra in every file).

    Returns
    -------
    tids : numpy.ndarray
        Array of integer ids read from the `TARGETID` column.
    X : numpy.ndarray
        Renormalized spectra: the first 443 columns of each row, with the
        weighted mean subtracted and divided by the weighted rms. Columns
        443 onward of the input are used as the weights.
    Y : numpy.ndarray
        Classification vector of shape (`nqso`, 5) with the following
        entries:
        STAR = (1,0,0,0,0), GAL = (0,1,0,0,0)
        QSO_LZ = (0,0,1,0,0), QSO_HZ = (0,0,0,1,0)
        BAD = (0,0,0,0,1)
    z : numpy.ndarray
        Array of redshifts.
    bal : numpy.ndarray
        Computed as ``bal_flag_vi * (bi_civ > 0) - (not bal_flag_vi) *
        (bi_civ == 0)``. For a 0/1 flag this is 1 for a flagged BAL with
        BI_CIV > 0, -1 for an unflagged spectrum with BI_CIV == 0, and 0
        otherwise.
    plate : numpy.ndarray
        Array of plate ids. Only returned when `return_pmf` is True.
    mjd : numpy.ndarray
        Array of Modified Julian Dates. Only returned when `return_pmf` is
        True.
    fid : numpy.ndarray
        Array of fiber ids. Only returned when `return_pmf` is True.
    """
    tids = []
    X = []
    if return_pmf:
        plate = []
        mjd = []
        fid = []

    # Built once rather than once per file.
    truth_ids = list(truth.keys()) if truth is not None else None

    for f in fi:
        print('INFO: reading data from {}'.format(f))
        with fitsio.FITS(f) as h:
            # Resolve the row count per file. Overwriting `nspec` here would
            # truncate every later file to the first file's length.
            n = h[1].get_nrows() if nspec is None else nspec
            aux_tids = h[1]['TARGETID'][:n].astype(int)

            # Remove thing_id == -1 or not in sdrq
            w = (aux_tids != -1)
            if truth_ids is not None:
                w_in_truth = np.isin(aux_tids, truth_ids)
                print((f"INFO: Removing {(~w_in_truth).sum()}"
                       " spectra missing in truth"), flush=True)
                w &= w_in_truth
            aux_tids = aux_tids[w]

            # Assumes HDU 0 rows hold 443 flux values followed by weights
            # (see the 443 column split below).
            aux_X = h[0][:n, :]
            aux_X = aux_X[w]

            if return_pmf:
                # Slice to the same n rows the mask was built from.
                plate += list(h[1]['PLATE'][:n][w])
                mjd += list(h[1]['MJD'][:n][w])
                fid += list(h[1]['FIBERID'][:n][w])

        X.append(aux_X)
        tids.append(aux_tids)

        print(f"INFO: Found {aux_tids.shape} spectra in file {f}")

    tids = np.concatenate(tids)
    X = np.concatenate(X)

    if return_pmf:
        plate = np.array(plate)
        mjd = np.array(mjd)
        fid = np.array(fid)

    # Drop spectra whose weights are all zero.
    we = X[:, 443:]
    w = we.sum(axis=1) == 0
    print("INFO: removing {} spectra with zero weights".format(w.sum()))
    X = X[~w]
    tids = tids[~w]

    if return_pmf:
        plate = plate[~w]
        mjd = mjd[~w]
        fid = fid[~w]

    # Weighted mean and rms of the flux columns.
    mdata = np.average(X[:, :443], weights=X[:, 443:], axis=1)
    sdata = np.average((X[:, :443] - mdata[:, None])**2,
                       weights=X[:, 443:], axis=1)
    sdata = np.sqrt(sdata)

    # Drop flat spectra, which cannot be normalised.
    w = sdata == 0
    print("INFO: removing {} spectra with zero flux".format(w.sum()))
    X = X[~w]
    tids = tids[~w]
    mdata = mdata[~w]
    sdata = sdata[~w]

    if return_pmf:
        plate = plate[~w]
        mjd = mjd[~w]
        fid = fid[~w]

    X = X[:, :443] - mdata[:, None]
    X /= sdata[:, None]

    if truth is None:
        if return_pmf:
            return tids, X, plate, mjd, fid
        else:
            return tids, X

    # Remove zconf == 0 (not inspected)
    observed = [(truth[t].class_person > 0) or (truth[t].z_conf_person > 0)
                for t in tids]
    observed = np.array(observed, dtype=bool)
    tids = tids[observed]
    X = X[observed]

    if return_pmf:
        plate = plate[observed]
        mjd = mjd[observed]
        fid = fid[observed]

    # Fill redshifts
    z = np.zeros(X.shape[0])
    z[:] = [truth[t].z_vi for t in tids]

    # Fill bal
    bal = np.zeros(X.shape[0])
    bal[:] = [(truth[t].bal_flag_vi * (truth[t].bi_civ > 0)) -
              (not truth[t].bal_flag_vi) * (truth[t].bi_civ == 0)
              for t in tids]

    # Fill classes
    # Classes: 0 = STAR, 1=GALAXY, 2=QSO_LZ, 3=QSO_HZ, 4=BAD (zconf != 3)
    nclasses = 5
    sdrq_class = np.array([truth[t].class_person for t in tids])
    z_conf = np.array([truth[t].z_conf_person for t in tids])

    Y = np.zeros((X.shape[0], nclasses))
    # STAR
    w = (sdrq_class == 1) & (z_conf == 3)
    Y[w, 0] = 1

    # GALAXY
    w = (sdrq_class == 4) & (z_conf == 3)
    Y[w, 1] = 1

    # QSO_LZ
    w = ((sdrq_class == 3) | (sdrq_class == 30)) & (z < z_lim) & (z_conf == 3)
    Y[w, 2] = 1

    # QSO_HZ
    w = ((sdrq_class == 3) | (sdrq_class == 30)) & (z >= z_lim) & (z_conf == 3)
    Y[w, 3] = 1

    # BAD
    w = z_conf != 3
    Y[w, 4] = 1

    # Check that all spectra have exactly one classification
    assert (Y.sum(axis=1).min() == 1) and (Y.sum(axis=1).max() == 1)

    if return_pmf:
        return tids, X, Y, z, bal, plate, mjd, fid

    return tids, X, Y, z, bal
# DESI Related IO below this point ###############################################################################
def load_desi_exposure(dir_name, spec_number, fibers=None, out_grid=wave):
    """Load and renormalize a raw DESI spectrographic exposure.

    This method will load B, R and Z cframe files in sequence. First, spectra
    are rebinned to the QuasarNet wavelength grid. Rebinned spectra are
    divided by the rebinned IVAR to reweight the spectra. Next, rebinned
    spectra are normalized by subtracting the weighted mean of the spectra
    and then dividing the resultant spectra by its weighted rms. The
    rebinned IVAR is used for weighting. Any spectra where the IVAR is 0 for
    the entire wavelength grid is discarded.

    Parameters
    ----------
    dir_name : str or Path
        Directory to load exposure from. Its last path component is used as
        the exposure id in the cframe file names.
    spec_number : int
        Spectrograph number to load, between 0 and 9.
    fibers : numpy.ndarray, optional
        Boolean array of length 500 indicating whether each fiber should be
        loaded. Defaults to None, which loads all 500 fibers.
    out_grid : numpy.ndarray, optional
        The wavelength grid to rebin the loaded exposure to. Defaults to the
        logarithmic QuasarNET grid.

    Returns
    -------
    X_out : numpy.ndarray
        Renormalized and rebinned spectra of shape
        `(nspectra, len(out_grid))`, e.g. 443 bins for the logarithmic grid
        or 458 for the linear grid.
    w : numpy.ndarray
        Integer indices, into the selected fibers, of the spectra kept in
        `X_out`. Spectra whose rebinned IVAR is zero everywhere are dropped.

    See Also
    --------
    load_desi_daily : Load a daily exposure.
    """
    if fibers is None:
        fibers = np.ones(500, dtype="bool")
    fibers = np.asarray(fibers)

    assert len(fibers) == 500, ("fibers input must include True/False"
                                " for all 500 fibers.")
    assert 0 <= spec_number and spec_number <= 9, ("spec_number must be"
                                                    " between 0 and 9.")

    file_loc = Path(dir_name)
    exp_id = file_loc.parts[-1]

    # Size the buffers from out_grid rather than a hard-coded 443 so that
    # non-default grids (e.g. the linear grid) can be accumulated into them.
    nfibers = np.sum(fibers > 0)
    X_out = np.zeros((nfibers, len(out_grid)))
    # ivar_out accumulates the weights used for normalization.
    ivar_out = np.zeros_like(X_out)

    # Load each cam sequentially, then rebin and merge.
    for c in ["B", "R", "Z"]:
        im_path = file_loc / f"cframe-{c.lower()}{spec_number}-{exp_id}.fits"
        with fitsio.FITS(im_path) as h:
            flux = h["FLUX"].read()[fibers, :]
            ivar = h["IVAR"].read()[fibers, :]
            w_grid = h["WAVELENGTH"].read()

        new_flux, new_ivar = rebin(flux, ivar, w_grid, out_grid=out_grid)
        X_out += new_flux
        ivar_out += new_ivar

    # Divide the accumulated flux by the accumulated IVAR wherever it is
    # nonzero (avoids division by zero).
    non_zero = ivar_out != 0
    X_out[non_zero] /= ivar_out[non_zero]

    nonzero_weights = np.sum(ivar_out, axis=1) != 0
    print(f"{nfibers - np.sum(nonzero_weights)} spectra with zero weights")
    X_out = X_out[nonzero_weights]
    ivar_out = ivar_out[nonzero_weights]

    X_out = renormalize(X_out, ivar_out)

    return X_out, np.where(nonzero_weights)[0]
def load_desi_coadd(filename, rows=None, out_grid=wave):
    """Load and renormalize a DESI coadded spectrographic exposure.

    This method will load a coadd file and renormalize as follows. First,
    spectra are rebinned to the QuasarNet wavelength grid. Rebinned spectra
    are divided by the rebinned IVAR to reweight the spectra. Next, rebinned
    spectra are normalized by subtracting the weighted mean of the spectra
    and then dividing the resultant spectra by its weighted rms. The
    rebinned IVAR is used for weighting. Any spectra where the IVAR is 0 for
    the entire wavelength grid is discarded.

    Parameters
    ----------
    filename : str
        Full path and filename of the coadd file to load. It must contain
        ``{B,R,Z}_FLUX``, ``{B,R,Z}_IVAR`` and ``{B,R,Z}_WAVELENGTH`` HDUs.
    rows : numpy.ndarray, optional.
        Boolean array indicating whether each row should be loaded. True if
        the row should be loaded, False otherwise. Defaults to None, which
        loads all rows.
    out_grid : numpy.ndarray, optional
        The wavelength grid to rebin the loaded exposure to.

    Returns
    -------
    X_out : numpy.ndarray
        Renormalized and rebinned spectra of shape
        `(nspectra, len(out_grid))`, e.g. 443 bins for the logarithmic grid
        or 458 for the linear grid.
    w : numpy.ndarray
        Integer indices, into the selected rows, of the spectra kept in
        `X_out`. Spectra whose rebinned IVAR is zero everywhere are dropped.

    See Also
    --------
    load_desi_daily : Load a daily exposure.
    """
    X_out = None
    ivar_out = None
    with fitsio.FITS(filename) as h:
        # Load each cam sequentially, then rebin and merge.
        for c in ["B", "R", "Z"]:
            flux = h[f"{c}_FLUX"].read()
            ivar = h[f"{c}_IVAR"].read()
            w_grid = h[f"{c}_WAVELENGTH"].read()
            if rows is not None:
                flux = flux[rows, :]
                ivar = ivar[rows, :]

            # Size the buffers from the selected flux. This avoids reading
            # B_FLUX a second time just to count its rows.
            if X_out is None:
                X_out = np.zeros((flux.shape[0], len(out_grid)))
                # ivar_out accumulates the weights used for normalization.
                ivar_out = np.zeros_like(X_out)

            new_flux, new_ivar = rebin(flux, ivar, w_grid, out_grid=out_grid)
            X_out += new_flux
            ivar_out += new_ivar

    # Divide the accumulated flux by the accumulated IVAR wherever it is
    # nonzero (avoids division by zero).
    non_zero = ivar_out != 0
    X_out[non_zero] /= ivar_out[non_zero]

    nonzero_weights = np.sum(ivar_out, axis=1) != 0
    X_out = X_out[nonzero_weights]
    ivar_out = ivar_out[nonzero_weights]

    X_out = renormalize(X_out, ivar_out)

    return X_out, np.where(nonzero_weights)[0]
def load_desi_daily(night, exp_id, spec_number, fibers=None, w_grid=wave):
    """Load and renormalize a daily DESI spectrographic exposure.

    This method will load B, R and Z cframe files in sequence. First, spectra
    are rebinned to the QuasarNet wavelength grid. Rebinned spectra are
    divided by the rebinned IVAR to reweight the spectra. Next, rebinned
    spectra are normalized by subtracting the weighted mean of the spectra
    and then dividing the resultant spectra by its weighted rms. The
    rebinned IVAR is used for weighting. Any spectra where the IVAR is 0 for
    the entire wavelength grid is discarded.

    Parameters
    ----------
    night : int or str
        Night on which the exposure was taken.
    exp_id : str
        Exposure ID of the exposure (used as the directory name).
    spec_number : int
        Spectrograph number to load, between 0 and 9.
    fibers : numpy.ndarray, optional.
        Boolean array of length 500 indicating whether each fiber should be
        loaded. Defaults to None, which loads all 500 fibers.
    w_grid : numpy.ndarray, optional
        The wavelength grid to rebin the loaded exposure to.

    Returns
    -------
    X_out : numpy.ndarray
        Renormalized and rebinned spectra of shape
        `(nspectra, len(w_grid))`, e.g. 443 bins for the logarithmic grid or
        458 for the linear grid.
    w : numpy.ndarray
        Integer indices, into the selected fibers, of the spectra kept in
        `X_out`.

    See Also
    --------
    load_desi_exposure : Used by load_desi_daily to load the given exposure.
    load_desi_coadd : Load a coadded exposure.
    """
    if fibers is None:
        fibers = np.ones(500, dtype="bool")

    assert len(fibers) == 500, ("fibers input must include True/False"
                                " for all 500 fibers.")
    assert 0 <= spec_number and spec_number <= 9, ("spec_number must be"
                                                    " between 0 and 9.")

    # For now load daily cframes files
    # TODO: add support for loading arbitrary cframes.
    # TODO: Add support for loading by tile id + e rather than date + e
    root = "/global/cfs/cdirs/desi/spectro/redux/daily/exposures"
    # Path() rejects ints, and night is documented as int or str.
    file_loc = Path(root, str(night), exp_id)

    # load_desi_exposure takes the output grid as `out_grid`.
    return load_desi_exposure(file_loc, spec_number, fibers, out_grid=w_grid)
# BOSS Related IO below this point ###############################################################################
def read_spall(file_loc):
    """Read plate/MJD/fiber to THING_ID metadata from an spAll file.

    Parameters
    ----------
    file_loc : string or Path
        Full path and filename of the spAll file to read.

    Returns
    -------
    tid : numpy.ndarray
        The THING_ID column of HDU 1, cast to int.
    pmf2tid : dict
        Dictionary mapping each (PLATE, MJD, FIBERID) tuple to its THING_ID.
        If a key repeats, the last row wins.
    """
    wanted = ["PLATE", "MJD", "FIBERID", "THING_ID", "SPECPRIMARY"]
    # A single read() of all needed columns is faster than slicing each one.
    with fitsio.FITS(file_loc) as spall:
        table = spall[1].read(columns=wanted)

    tid = table["THING_ID"].astype(int)

    # NOTE(review): SPECPRIMARY is read but not used to filter the mapping;
    # confirm whether non-primary spectra were meant to be excluded.
    pmf_keys = zip(table["PLATE"], table["MJD"], table["FIBERID"])
    pmf2tid = dict(zip(pmf_keys, tid))

    return tid, pmf2tid