diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 677a7d428..646ae725b 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -87,6 +87,10 @@ import_check: stage: system tests script: - python .travis/test-all-mod.py + - python -m pytest -q MonteCarloMarginalizeCode/Code/test/test_lisa_response_import.py + - python -m pytest -q MonteCarloMarginalizeCode/Code/test/test_lisa_lalsimutils_compat.py + - python -m pytest -q MonteCarloMarginalizeCode/Code/test/test_lisa_operational_synthetic.py + - python -m pytest -q MonteCarloMarginalizeCode/Code/test/test_lisa_helper_contract.py sim_manager_check: stage: unit tests diff --git a/.travis/test-all-bin.sh b/.travis/test-all-bin.sh index d71bdcbd6..f040ecd74 100755 --- a/.travis/test-all-bin.sh +++ b/.travis/test-all-bin.sh @@ -17,6 +17,10 @@ set -e # loop over all bin/ scripts for EXE in MonteCarloMarginalizeCode/Code/bin/*; do + if [[ ! -f ${EXE} ]]; then + echo " Not file : " ${EXE} + continue + fi # skip scripts with explicit bilby dependence if [[ ${EXE} == *'calibration_reweighting.py' ]]; then continue @@ -70,6 +74,16 @@ for EXE in MonteCarloMarginalizeCode/Code/bin/*; do echo "Hyperpipe " ${EXE} continue fi + # skip optional JAX driver; base CI does not install the jax extras + if [[ ${EXE} == *"integrate_likelihood_extrinsic_jax" ]]; then + echo "JAX optional " ${EXE} + continue + fi + # skip optional HTCondor2 driver; base CI does not install htcondor2 + if [[ ${EXE} == *"cepp_basic_htcondor" ]]; then + echo "HTCondor2 optional " ${EXE} + continue + fi # skip tests that require condor environment if [[ ${EXE} == *"check_CIP_complete_work.py" ]]; then continue diff --git a/.travis/test-all-mod.py b/.travis/test-all-mod.py index 8e7765051..3d0a18e6a 100755 --- a/.travis/test-all-mod.py +++ b/.travis/test-all-mod.py @@ -57,6 +57,20 @@ r"\ANo module named 'cupy'\Z", r"\ANo module named asimov\Z", r"\ANo module named 'asimov'\Z", + r"\ANo module named jax\Z", + r"\ANo module named 'jax'\Z", + r"\ANo module named optax\Z", + r"\ANo module named 'optax'\Z", + r"\ANo module named equinox\Z", + r"\ANo module named 'equinox'\Z", + r"\ANo module named tinygp\Z", + r"\ANo module named 'tinygp'\Z", + r"\ANo module named numpyro\Z", + r"\ANo module named 'numpyro'\Z", + r"\ANo module named flowMC\Z", + r"\ANo module named 'flowMC'\Z", + r"\ANo module named htcondor2\Z", + r"\ANo module named 'htcondor2'\Z", ]))) diff --git a/CHANGES.rst b/CHANGES.rst index ac1b393c9..cc70904b7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2,6 +2,11 @@ 0.0.18.0 ------------ development tree is rift_O4d. + - (rc0) O4d base refresh, from rift_O4c to rift_O4d: Python/numpy CI modernization (py3.10-py3.13, + numpy 2.x checks), Asimov/RIFT smoke tests, docs deployment, pluggable workflow backends and + simulation-manager prototypes, distance-grid/distance-slice likelihood export, container-family and pixi/SWIG + canaries, hyperpipeline ASCII workflow support, parsimonious-placement preview, EFPE and standalone NR-frame + utilities, plot_RIFT diagnostics, and GPU/CuPy portability fixes. ** (rc0) - parsimonious-placement (preview): new RIFT.misc.tracer_placement engine (SMC+MALA, birth-death, @@ -37,6 +42,24 @@ development tree is rift_O4d. - ascii data format; waveforms (epfe); container auto-selection framework (first draft); cepp_alternate now feature-parity; mcsamplerEnsemble full GPU path - multiple demos + - (rc1 pending) In-loop calibration marginalization: move calibration draws into ILE, add loop and fused + GPU kernels (including distance-marginalized and phase-marginalized paths), export cal/time diagnostics, + account for calibration Monte-Carlo error in reported uncertainty, add adaptive/pilot calibration proposal + scaffolding, and include a runnable calmarg demo with OSG/container transfer fixes. + - (rc1 pending) Extrinsic proposal handoff and sampler robustness: save and consolidate extrinsic GMM + breadcrumbs across iterations, seed later ILE jobs from those proposals, add GMM proposal/adaptation + controls and diagnostics, improve log-domain covariance/ESS handling, and thread the cal/extrinsic seed + barriers through CEPP/puffball workflows. + - (rc1 pending) Differentiable JAX likelihood/export work: optional jax_gp interpolators, differentiable + export artifacts, quad-GP and RF/GP validation tooling, jax_cip experiments, and a JAX ILE driver with + AD-compatible extrinsic likelihood, distance/phase marginalization, network coordinates, and gradient-aware + samplers. These remain optional extras and are skipped by base CI when JAX is not installed. + - (rc1 pending) LISA support: add a LISA compatibility layer, response import path, standalone helper + contract, LISA ILE scaffold, and synthetic operational/import tests. + - (rc1 pending) EOS/coordinate and workflow updates: add NMB/PCA/tabular EOS sequence dispatch and + single-EOS-index evidence support, coordinate-plugin hooks for posterior plotting and puffing, Morisaki-frame + and aligned-spin coordinate helpers for GP/JAX studies, container-universe/OSDF scitokens support, and + per-machine container image selection. 0.0.17.9 ------------ diff --git a/MonteCarloMarginalizeCode/Code/RIFT/LISA/__init__.py b/MonteCarloMarginalizeCode/Code/RIFT/LISA/__init__.py new file mode 100644 index 000000000..01d2971e3 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/LISA/__init__.py @@ -0,0 +1,2 @@ +"""LISA-specific response and analysis helpers for RIFT.""" + diff --git a/MonteCarloMarginalizeCode/Code/RIFT/LISA/lalsimutils_compat.py b/MonteCarloMarginalizeCode/Code/RIFT/LISA/lalsimutils_compat.py new file mode 100644 index 000000000..71441effb --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/LISA/lalsimutils_compat.py @@ -0,0 +1,283 @@ +"""Compatibility helpers for the standalone LISA ILE fork. + +These routines are adapted from Aasim Z. Jan's LISA-RIFT +``lisa_rift_paper`` branch. Keep LISA-only shims here while the normal +``RIFT.lalsimutils`` and ILE driver continue to evolve independently. +""" + +import os +import sys + +import h5py +import lal +import lalsimulation as lalsim +import numpy as np + +import RIFT.lalsimutils as lalsimutils + + +def _filter_modes(hlms, modes): + if modes is None: + return hlms + + requested_modes = {tuple(map(int, mode)) for mode in modes} + return {mode: hlm for mode, hlm in hlms.items() if mode in requested_modes} + + +def hlmoff_for_LISA( + P, + Lmax=4, + modes=None, + fd_standoff_factor=0.964, + fd_alignment_postevent_time=None, + path_to_NR_hdf5=None, + **kwargs +): + """Generate frequency-domain modes with the sizing expected by LISA ILE.""" + assert Lmax >= 2 + assert (not np.isnan(P.m1)) and (not np.isnan(P.m2)), "Masses are NaN." + + extra_waveform_args = {} + if "extra_waveform_args" in kwargs: + extra_waveform_args.update(kwargs["extra_waveform_args"]) + extra_params = P.to_lal_dict_extended(extra_args_dict=extra_waveform_args) + + fNyq = 0.5 / P.deltaT + TDlen = int(1.0 / (P.deltaT * P.deltaF)) + if fd_alignment_postevent_time: + if fd_alignment_postevent_time >= TDlen * P.deltaT / 2: + print( + " Warning: fd alignment postevent time requested incompatible with short duration ", + file=sys.stderr, + ) + + if path_to_NR_hdf5 is not None: + hlms_struct = lalsimutils.hlmoft_from_NRhdf5( + path_to_NR_hdf5, + P, + Lmax, + only_mode=modes, + taper_percent=10, + beta=8, + verbose=True, + ) + return {mode: lalsimutils.DataFourier(hlms_struct[mode]) for mode in hlms_struct} + + fd_mode_approximants = { + lalsimutils.lalIMRPhenomHM, + lalsimutils.lalIMRPhenomXPHM, + lalsimutils.lalIMRPhenomXHM, + } + if P.approx in fd_mode_approximants: + hlms_struct = lalsim.SimInspiralChooseFDModes( + P.m1, + P.m2, + P.s1x, + P.s1y, + P.s1z, + P.s2x, + P.s2y, + P.s2z, + P.deltaF, + P.fmin * fd_standoff_factor, + fNyq, + P.fref, + P.phiref, + P.dist, + P.incl, + extra_params, + P.approx, + ) + hlmsdict = lalsimutils.SphHarmFrequencySeries_to_dict(hlms_struct, Lmax) + hlmsdict = _filter_modes(hlmsdict, modes) + return { + mode: lal.ResizeCOMPLEX16FrequencySeries(hlm, 0, TDlen) + for mode, hlm in hlmsdict.items() + } + + if P.approx in {lalsimutils.lalNRHybSur3dq8, lalsimutils.lalIMRPhenomD}: + hlms_struct = lalsimutils.hlmoff(P, Lmax=Lmax) + hlmsdict = lalsimutils.SphHarmFrequencySeries_to_dict(hlms_struct, Lmax) + hlmsdict = _filter_modes(hlmsdict, modes) + for mode, hlm in list(hlmsdict.items()): + if not (1 / hlm.deltaF == P.deltaT * TDlen): + print( + "WARNING: RESIZING IN FD DOMAIN " + f"(1/deltaF = {1 / hlm.deltaF}, deltaT*TDlen = {P.deltaT * TDlen}), " + "THIS SHOULD NOT BE HAPPENING." + ) + hlmsdict[mode] = lal.ResizeCOMPLEX16FrequencySeries(hlm, 0, TDlen) + return hlmsdict + + raise ValueError(f"Unsupported LISA FD mode approximant: {P.approx}") + + +def _cache_path_for_channel(fname, channel): + cache_data = np.loadtxt(fname, dtype=str) + if cache_data.ndim == 1: + cache_data = cache_data.reshape(1, len(cache_data)) + + channel_prefix = channel[0] + for row in cache_data: + if row[0] == channel_prefix: + raw_path = row[-1] + if "localhost" in raw_path: + raw_path = raw_path.split("localhost", 1)[1] + elif raw_path.startswith("file://"): + raw_path = raw_path[len("file://"):] + return os.path.expanduser(raw_path) + + raise ValueError(f"Could not find channel {channel!r} in cache {fname!r}") + + +def frame_h5_to_hoff(fname, channel, start=None, stop=None, verbose=True): + """Read LISA frequency-domain HDF5 data from a cache entry.""" + if verbose: + print(" ++ Loading from cache ", fname, channel) + + path_to_h5 = _cache_path_for_channel(fname, channel) + if verbose: + print(f"Reading h5 file {path_to_h5}") + + with h5py.File(path_to_h5, "r") as data: + hoff = lal.CreateCOMPLEX16FrequencySeries( + "hoff", + data.attrs["epoch"], + data.attrs["f0"], + data.attrs["deltaF"], + lalsimutils.lsu_HertzUnit, + int(data.attrs["length"]), + ) + hoff.data.data = data["data"][()] + + return hoff + + +def frame_h5_to_hoft(fname, channel, start=None, stop=None, verbose=True): + """Read LISA time-domain HDF5 data from a cache entry.""" + if verbose: + print(" ++ Loading from cache ", fname, channel) + + path_to_h5 = _cache_path_for_channel(fname, channel) + if verbose: + print(f"Reading h5 file {path_to_h5}") + + with h5py.File(path_to_h5, "r") as data: + hoft = lal.CreateREAL8TimeSeries( + "hoft", + data.attrs["epoch"], + data.attrs["f0"], + data.attrs["deltaT"], + lal.DimensionlessUnit, + int(data.attrs["length"]), + ) + hoft.data.data = data["data"][()] + + return hoft + + +def frame_data_to_non_herm_hoff( + fname, + channel, + start=None, + stop=None, + TDlen=0, + window_shape=0.0, + verbose=True, + deltaT=None, + h5_frame=False, +): + """Read frame/HDF5 time-domain data and FFT to two-sided frequency data.""" + if not h5_frame: + return lalsimutils.frame_data_to_non_herm_hoff( + fname, + channel, + start=start, + stop=stop, + TDlen=TDlen, + window_shape=window_shape, + verbose=verbose, + deltaT=deltaT, + ) + + ht = frame_h5_to_hoft(fname, channel, start, stop, verbose) + + tmplen = ht.data.length + if TDlen == -1: + TDlen = tmplen + elif TDlen == 0: + TDlen = lalsimutils.nextPow2(tmplen) + else: + assert TDlen >= tmplen + + ht = lal.ResizeREAL8TimeSeries(ht, 0, TDlen) + hoftC = lal.CreateCOMPLEX16TimeSeries( + "h(t)", ht.epoch, ht.f0, ht.deltaT, ht.sampleUnits, TDlen + ) + hoftC.data.data = ht.data.data + 0j + fwdplan = lal.CreateForwardCOMPLEX16FFTPlan(TDlen, 0) + hf = lal.CreateCOMPLEX16FrequencySeries( + "Template h(f)", + ht.epoch, + ht.f0, + 1.0 / ht.deltaT / TDlen, + lalsimutils.lsu_HertzUnit, + TDlen, + ) + lal.COMPLEX16TimeFreqFFT(hf, hoftC, fwdplan) + if verbose: + print( + " ++ Loaded data h(f) of length n= ", + hf.data.length, + " (= ", + len(hf.data.data) * ht.deltaT, + "s) at sampling rate ", + 1.0 / ht.deltaT, + ) + return hf + + +def print_params_lisa(self, show_system_frame=False): + """LISA-flavored parameter printer used by the standalone ILE fork.""" + print("This ChooseWaveformParams has the following parameter values:") + print(f"m1 = {self.m1 / lalsimutils.lsu_MSUN / 1e3:0.3f} x 1e3 (Msun)") + print(f"m2 = {self.m2 / lalsimutils.lsu_MSUN / 1e3:0.3f} x 1e3 (Msun)") + print("s1x =", self.s1x) + print("s1y =", self.s1y) + print("s1z =", self.s1z) + print("s2x =", self.s2x) + print("s2y =", self.s2y) + print("s2z =", self.s2z) + if show_system_frame: + self.print_params(show_system_frame=show_system_frame) + print("lambda1 =", self.lambda1) + print("lambda2 =", self.lambda2) + print("inclination =", self.incl) + print("distance =", self.dist / 1.0e9 / lalsimutils.lsu_PC, "(Gpc)") + print("reference orbital phase =", self.phiref) + print("polarization angle =", self.psi) + print("eccentricity = ", self.eccentricity) + print("reference time = ", float(self.tref), "(s)") + print("detector is: LISA") + print("ecliptic latitude (beta):", self.theta, "(radians)") + print("ecliptic longitude (lambda):", self.phi, "(radians)") + print("starting frequency is =", self.fmin, "(Hz)") + print("reference frequency is =", self.fref, "(Hz)") + print("Max frequency is =", self.fmax, "(Hz)") + print("time step =", self.deltaT, "(s) <==>", 1.0 / self.deltaT, "(Hz) sample rate") + print("freq. bin size is =", self.deltaF, "(Hz)") + + +def install_choose_waveform_print_params_lisa(): + """Install the LISA printer on ChooseWaveformParams for the LISA fork.""" + lalsimutils.ChooseWaveformParams.print_params_lisa = print_params_lisa + + +__all__ = [ + "frame_data_to_non_herm_hoff", + "frame_h5_to_hoff", + "frame_h5_to_hoft", + "hlmoff_for_LISA", + "install_choose_waveform_print_params_lisa", + "print_params_lisa", +] diff --git a/MonteCarloMarginalizeCode/Code/RIFT/LISA/response/LISA_response.py b/MonteCarloMarginalizeCode/Code/RIFT/LISA/response/LISA_response.py new file mode 100644 index 000000000..a8855671f --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/LISA/response/LISA_response.py @@ -0,0 +1,727 @@ +"""This code is based on http://arxiv.org/abs/1806.10734 and 10.1103/PhysRevD.103.083011, with parts of code being taken from BBHx's C code. The response has been validated again BBHx with mismatch of around 10^(-10). This assumes fixed LISA arm length and is not the best formalism for Stellar mass binaries.""" + +# TO DO: +## Better documentation +## Impact of different fourier convention. Right now we follow theirs, not lal's. + +# Future work +## 1) Add a different response code for Stellar mass binaries. + +# Speed ups: +# https://stackoverflow.com/questions/49493482/numpy-np-multiply-vs-operator np.multiply vs * (no speedup) +# https://stackoverflow.com/questions/25870923/how-to-square-or-raise-to-a-power-elementwise-a-2d-numpy-array np.square vs **2 (no speedup for array of length 1e7). np.power is slow +# https://stackoverflow.com/questions/49459661/differences-between-numpy-divide-and-python-divide np.divide vs / (no speedup) + +import numpy as np +import lal +import RIFT.lalsimutils as lsu +import RIFT.LISA.lalsimutils_compat as lisa_lalsimutils_compat +import sys +import h5py + +__author__ = "A. Jan" +########################################################################################### +# Constants +########################################################################################### +e = 0.004824185218078991 +omega0 = 1.9909865927683788 * 10**(-7) #1/seconds +a = 149597870700. #meters +C_SI = 299792458. +L = 2*np.sqrt(3)*a*e +YRSID_SI = 31558149.763545603 + +########################################################################################### +# FUNCTIONS +########################################################################################### +def create_lal_frequency_series(frequency_values, frequency_series, deltaF, epoch = 950000000, f0 = 0.0): + """A helper function to create lal COMPLEX16FrequencySeries. + Args: + frequency_values (numpy.array): Frequency values at which the series is defined. + frequency_series (numpy.array): Corresponding strain in frequency domain. + epoch (float): Needed to create COMPLEX16FrequencySeries, by default it is 950000000. + f0 (float): Needed to create COMPLEX16FrequencySeries, by default it is 0.0 . + Output: + lal.COMPLEX16FrequencySeries object""" + assert len(frequency_values) == len(frequency_series), "frequency_values and frequency_series don't have the same length." + hf_lal = lal.CreateCOMPLEX16FrequencySeries("hf", epoch, f0, deltaF, lal.HertzUnit, len(frequency_values)) + hf_lal.data.data = frequency_series + return hf_lal + +def create_lal_COMPLEX16TimeSeries(deltaT, time_series, epoch = 950000000, f0 = 0.0, data_is_real = True): + """A helper function to create lal COMPLEX16TimeSeries. + Args: + deltaT (float): time step (1/sampling rate). + time_series (numpy.array): strain in time domain. + epoch (float): Needed to create COMPLEX16TimeSeries, by default it is 950000000. + f0 (float): Needed to create COMPLEX16TimeSeries, by default it is 0.0 . + Output: + lal.COMPLEX16TimeSeries object""" + ht_lal = lal.CreateCOMPLEX16TimeSeries("ht_lal", epoch, f0, deltaT, lal.DimensionlessUnit, len(time_series)) + if data_is_real: + ht_lal.data.data = time_series + 0j + else: + ht_lal.data.data = time_series + return ht_lal + + +def convert_double_sided_to_single_sided(frequency_values, frequency_series, data_defined="negative"): + assert len(frequency_values) == len(frequency_series), "frequency_values and frequency_series don't have the same length." + if data_defined == "negative": + print("Negative") + index = np.argwhere(frequency_values<=0).flatten() + hf_onesided = create_lal_frequency_series(frequency_values[index], np.conj(frequency_series[index][::-1])) # do I need to conjugate? + assert len(frequency_series)//2 + 1 == hf_onesided.data.length + return hf_onesided + elif data_defined == "positive": + print("Positve") + index = np.argwhere(frequency_values>=0).flatten() + hf_onesided = create_lal_frequency_series(frequency_values[index], frequency_series[index]) + assert len(frequency_series)//2 + 1 == hf_onesided.data.length + return hf_onesided + else: + raise ValueError("data_defined must be either 'negative' or 'positive'") + +def frequency_series_double_sided(frequency_values, frequency_series, data_defined = "negative"): + hf_small = convert_double_sided_to_single_sided(frequency_values, frequency_series, data_defined) + tmp = np.zeros(len(frequency_series), dtype=complex) + tmp[:len(hf_small.data.data)] = np.conj(hf_small.data.data[::-1]) + tmp[len(hf_small.data.data)-1:] = (hf_small.data.data[:-1]) + hf = create_lal_frequency_series(frequency_values, tmp) + return hf + +def get_fvals(frequency_series): + """A function to evaulate frequency values of a COMPLEX16FrequencySeries. Goes from [-fNyq, fNyq - deltaF]. + Args: + frequency_series (COMPLEX16FrequencySeries): + Output: + frequency array (numpy.array)""" + fvals = -frequency_series.deltaF*np.arange(frequency_series.data.length//2, -frequency_series.data.length//2, -1) + return fvals + + + +def get_Ylm(inclination, phiref, l ,m, s = -2): + """A function to call spherical harmonics, should change to calling the GPU version in future but for now, I am only using to create injections so it should be fine. + Args: + inclination (float): inclination in SSB frame, + phiref (float): phase at coalescence, + l, m: modes, + s: spin weight, -2 by default. + Output: + Spin weighted spherical harmonics (complex)""" + return lal.SpinWeightedSphericalHarmonic(inclination, phiref, s, l, m) + +def get_closest_index(array, value): + """A function that gives you the index at which the array has a value closest to the one you want. + Args: + array (np.array): + value (float): + Output: + index (float)""" + return np.argmin(np.abs(array - value)) + +def transformed_Hplus_Hcross(beta, lamda, psi, theta, phiref, l, m): + """This function transforms the plus and cross polarization from wave frame to SSB frame. + Args: + beta (float) = ecliptic latitude, + lamda (float) = ecliptic longitude, + psi (float) = polarization + Returns: + transformed Plm (used to create the transfer function, equation 21 (http://arxiv.org/abs/2003.00357) + """ + + # defining cos and sin for ease (checked) + cl, sl = np.cos(lamda), np.sin(lamda) + cb, sb = np.cos(beta), np.sin(beta) + cp, sp = np.cos(psi), np.sin(psi) + + # polarization in waveframe (checked) + Hp = [[1,0,0], [0,-1,0], [0,0,0]] + Hc = [[0,1,0], [1,0,0] , [0,0,0]] + + # Rotation matrix (checked) + O1 = np.zeros((3,3)) + + O1[0,0] = cp*sl - cl*sb*sp + O1[0,1] = -cl*cp*sb - sl*sp + O1[0,2] = -cb*cl + + O1[1,0] = -cl*cp - sb*sl*sp + O1[1,1] = -cp*sb*sl + cl*sp + O1[1,2] = -cb*sl + + O1[2,0] = cb*sp + O1[2,1] = cb*cp + O1[2,2] = -sb + + # Transpose of Rotation matrix (checked) + TO1 = np.zeros((3,3)) + + TO1[0,0] = cp*sl - cl*sb*sp + TO1[0,1] = -cp*cl - sl*sb*sp + TO1[0,2] = cb*sp + + TO1[1,0] = -cl*cp*sb - sl*sp + TO1[1,1] = -sl*cp*sb + cl*sp + TO1[1,2] = cb*cp + + TO1[2,0] = -cb*cl + TO1[2,1] = -cb*sl + TO1[2,2] = -sb + + # Ylm factors (checked) + # For injection, we shouldn't take -phiref as we do in marginalization. Confirm with ROS. + Ylm = get_Ylm(theta, phiref, l, m, -2) + Y_lm = (-1)**(l) * np.conj(get_Ylm(theta, phiref, l, -m, -2)) + + Yfactorplus = 0.5*(Ylm + Y_lm) + Yfactorcross = 0.5*1j*(Ylm - Y_lm) + + Plm = Yfactorplus*np.matmul(O1, np.matmul(Hp,TO1)) + Yfactorcross*np.matmul(O1, np.matmul(Hc,TO1)) + return Plm + + +def get_tf_from_phase(hlm, fmax, debug = False):#tested + """This function differentiates phase to get tf. Similar to pycbc's time_from_frequencyseries (waveforms/utils.py) function but does not include their discontinuity check. (Now it does have those checks). + Args: + hlm (COMPLEX16FrequencySeries): The mode for which you are calculating tf , + fmax (float): maximum frequency (fNyq for RIFT). + Returns: + tf (numpy.array): Time, + frequency (numpy.array): Frequency (numpy.array). + """ + # send in hlm not shifted in time + + # get frequency and mode data + freq = np.arange(-fmax, fmax+hlm.deltaF, hlm.deltaF) + if debug: + print(f"len(freq) = {len(freq)}, freq[0] = {freq[0]} Hz, freq[-1] = {freq[-1]} Hz, len(hlm) = {hlm.data.length}") + + # get amplitude and phase + phase = np.unwrap(np.angle(hlm.data.data)) + # compute tf = -1/(2pi) * d(phase)/df + phase = phase - phase[0] #Pycbc does this, doesn't change answer as expected. + dphi = np.unwrap(np.diff(phase)) + time = -dphi / (2.*np.pi*np.diff(freq)) + # diff reduces len by 1 so artifically increasing it by adding an extra zero at the end + tmp = np.zeros(len(time)+1) + tmp[:-1] = time + time = tmp + + # only focusing on f bins where data exists + nzidx = np.nonzero(abs(hlm.data.data))[0] + if debug: + print(nzidx) + print(f"tf[0](after stripping zeros) = {time[kmin]}, tf[-1](after stripping zeros) {time[kmax]}") + kmin, kmax = nzidx[0], nzidx[-2] + time[:kmin] = time[kmin] + time[kmax:] = time[kmax] + if debug: + print(f"len(time) = {len(time)}, len(freq) = {len(freq)}") + + # saving data + return time, freq[::-1] #inverting frequency array to match http://arxiv.org/abs/1806.10734i fourier convention, will change it as we validate these codes. + +def Evaluate_Gslr(tf, f, H, beta, lamda): + """This function takes in tf, f, Plm (from transformed_Hplus_Hcross), beta and lamda to generate transfer function for a given mode. yslr = Sum_l^r Gslr * hlm. + Args: + tf (numpy.array)= -1/2pi d(phase)/df, + f (numpy.array)= frequency array, + H (numpy.array)= The plus and cross polarization matrices transformed from wave to SSB frame , + beta (float) = ecliptic latitude, + lamda (float) = ecliptic longitude + Returns: + Transfer function L1 (numpy.array), Transfer function L2 (numpy.array), Transfer function L3 (numpy.array) + """ + alpha = omega0*tf + c, s = np.cos(alpha), np.sin(alpha) + k = np.array([-np.cos(beta)*np.cos(lamda), -np.cos(beta)*np.sin(lamda), -np.sin(beta)]) + p0 = np.array([a*c, a*s, np.zeros(len(tf))]) # (3, N) + kR = np.dot(k, p0) # (N,) + phaseRdelay = 2.*np.pi/C_SI *f*kR #(N,) + + p1L = np.array([-a*e*(1 + s*s), a*e*c*s, -a*e*np.sqrt(3)*c]) # (3, N) + p2L = np.array([a*e/2*(np.sqrt(3)*c*s + (1 + s*s)), a*e/2*(-c*s - np.sqrt(3)*(1 + c*c)), -a*e*np.sqrt(3)/2*(np.sqrt(3)*s - c)]) # (3, N) + p3L = np.array([a*e/2*(-np.sqrt(3)*c*s + (1 + s*s)), a*e/2*(-c*s + np.sqrt(3)*(1 + c*c)), -a*e*np.sqrt(3)/2*(-np.sqrt(3)*s - c)]) # (3, N) + + + n1 = np.array([-1./2*c*s, 1./2*(1 + c*c), np.sqrt(3)/2*s]) # (3, N) + kn1= np.dot(k, n1) #(N,) + n1Hn1 = np.einsum("ij,ji->i",n1.T, np.einsum("ij,jk", H, n1)) + + n2 = 1./4. * np.array([c*s - np.sqrt(3)*(1 + s*s), np.sqrt(3)*c*s - (1 + c*c), -np.sqrt(3)*s - 3*c]) + kn2= np.dot(k, n2) + n2Hn2 = np.einsum("ij,ji->i",n2.T, np.einsum("ij,jk", H, n2)) + + n3 = 1./4*np.array([c*s + np.sqrt(3)*(1 + s*s), -np.sqrt(3)*c*s - (1 + c*c), -np.sqrt(3)*s + 3*c]) + kn3= np.dot(k, n3) + n3Hn3 = np.einsum("ij,ji->i",n3.T, np.einsum("ij,jk", H, n3)) + + + + kp1Lp2L = np.dot(k, (p1L+p2L)) + kp2Lp3L = np.dot(k, (p2L+p3L)) + kp3Lp1L = np.dot(k, (p3L+p1L)) + kp0 = np.dot(k, p0) + + factorcexp0 = np.exp(1j*2.*np.pi*f/C_SI * kp0) + prefactor = np.pi*f*L/C_SI + + factorcexp12 = np.exp(1j*prefactor * (1.+kp1Lp2L/L)) + factorcexp23 = np.exp(1j*prefactor * (1.+kp2Lp3L/L)) + factorcexp31 = np.exp(1j*prefactor * (1.+kp3Lp1L/L)) + + factorsinc12 = np.sinc( (prefactor * (1.-kn3))/np.pi) + factorsinc21 = np.sinc( (prefactor * (1.+kn3))/np.pi) + factorsinc23 = np.sinc( (prefactor * (1.-kn1))/np.pi) + factorsinc32 = np.sinc( (prefactor * (1.+kn1))/np.pi) + factorsinc31 = np.sinc( (prefactor * (1.-kn2))/np.pi) + factorsinc13 = np.sinc( (prefactor * (1.+kn2))/np.pi) + + commonfac = 1j*prefactor*factorcexp0 + G12 = commonfac * n3Hn3 * factorsinc12 * factorcexp12 * np.exp(-1j*phaseRdelay) + G21 = commonfac * n3Hn3 * factorsinc21 * factorcexp12 * np.exp(-1j*phaseRdelay) + G23 = commonfac * n1Hn1 * factorsinc23 * factorcexp23 * np.exp(-1j*phaseRdelay) + G32 = commonfac * n1Hn1 * factorsinc32 * factorcexp23 * np.exp(-1j*phaseRdelay) + G31 = commonfac * n2Hn2 * factorsinc31 * factorcexp31 * np.exp(-1j*phaseRdelay) + G13 = commonfac * n2Hn2 * factorsinc13 * factorcexp31 * np.exp(-1j*phaseRdelay) + + x = np.pi*f*L/C_SI + z = np.exp(1j*2.*x) + + factor_convention = 2 + factorAE = 1j*np.sqrt(2)*np.sin(2.*x)*z + factorT = 2.*np.sqrt(2)*np.sin(2.*x)*np.sin(x)*np.exp(1j*3.*x) + + Araw = 0.5 * ( (1.+z)*(G31 + G13) - G23 - z*G32 - G21 - z*G12 ) + Eraw = 0.5*1/np.sqrt(3) * ( (1.-z)*(G13 - G31) + (2.+z)*(G12 - G32) + (1.+2.*z)*(G21 - G23) ) + Traw = 1/np.sqrt(6) * (G21 - G12 + G32 - G23 + G13 - G31) + transferL1 = factor_convention * factorAE * Araw + transferL2 = factor_convention * factorAE * Eraw + transferL3 = factor_convention * factorT * Traw + + return transferL1, transferL2, transferL3 + + +########################################################################################### +# FUNCTIONS USED IN RECOVERY +########################################################################################### +def get_amplitude_phase(hf): #tested + # NOTE: Send in hlm not shifted in time, I add time shift in precomputation. + """This function splits h(f) into Amplitude A(f) and phase phase(f). We then express h_lm as A_lm(f) * exp(i*phase(f)). Note the sign in the exponent. + Args: + hf (CreateCOMPLEX16FrequencySeries): Frequency domain waveform. + Returns: + Amplitude as a function of frequency (numpy.array), phase as a function of frequency (numpy.array) + """ + return np.abs(hf.data.data), np.unwrap(np.angle(hf.data.data)) + + + +def get_tf_from_phase_dict(hlm, fmax, fref=None, debug=True, shift=True):#tested + """This function differentiates phase for each mode to get tf. Similar to pycbc's time_from_frequencyseries (waveforms/utils.py) function. + Args: + hlm (dict): mode dict generated by std_and_conj_hlmoff or any other hlm(f) function in RIFT, + fmax (float): maximum frequency (fNyq for RIFT). + Returns: + tf (dict): tf for each mode in hlm, + frequency (dict): corresponding frequency for each mode in hlm, + amplitude (dict): frequency domain amplitude for each mode, + phase (dict): frequency domain phase for each mode. + """ + # send in hlm not shifted in time + tf_dict = {} + freq_dict = {} + amp_dict = {} + phase_dict = {} + print("Computing time frequency correspondence for mode") + modes = np.array(list(hlm.keys())) + #freq = -lsu.evaluate_fvals(hlm[tuple(modes[0])]) # THIS CONSUMES MOST TIME + freq = -hlm[tuple(modes[0])].deltaF*np.arange(hlm[tuple(modes[0])].data.length//2, -hlm[tuple(modes[0])].data.length//2, -1) # this matches evaluate_fvals without using for loops + for mode in modes: + print(f"\n\tMode = {mode}") + mode = tuple(mode) + # get amplitude and phase + amp, phase = get_amplitude_phase(hlm[mode]) + # compute tf = -1/(2pi) * d(phase)/df + dphi = np.unwrap(np.diff(phase)) + time = np.divide(-dphi, (2.*np.pi*hlm[mode].deltaF)) + # diff reduces len by 1 so artifically increasing it by adding an extra zero at the end + tmp = np.zeros(len(time)+1) + tmp[:-1] = time + time = tmp + + # only focusing on f bins where data exists + # I had to introduce this statement since sometimes a mode doesn't have data (odd m modes are not excited for q=1, so the mode content is all zero.) + try: + nzidx = np.nonzero(abs(hlm[mode].data.data))[0] + kmin, kmax = nzidx[0], nzidx[-2] + time[:kmin] = time[kmin] + time[kmax:] = time[kmax] + except: + print(f"No data for {mode}") + pass + # saving data + tf_dict[mode] = time + freq_dict[mode] = freq[::-1] + amp_dict[mode] = amp + phase_dict[mode] = phase + + if shift: + modes = list(hlm.keys()) + print(f"Shifting of time and phase with fref = {fref}.") + assert (2,2) in modes, "(2,2) mode needs to be present." + # phase and tf shifts + if not fref: + # if fref not provided, set it to frequency at max (f^2 * A_{2,2}(f)) (BBHx) + fref = freq_dict[2,2][np.argmax(freq_dict[2,2]**2 * amp_dict[2,2])] # frequency at max (f^2 * A_{2,2}(f)) + + # find tf at fref + index_at_fref = get_closest_index(freq_dict[2,2], fref) + tf_22_current = tf_dict[2,2][index_at_fref] + phase_22_current = phase_dict[2,2][index_at_fref] + + time_shift = tf_22_current + reference_phase = 0.0 + + # for loop needs to start with (2,2) mode + modes.remove((2,2)) + modes.insert(0, (2,2)) + if debug: + print(f"tf[2,2] at fref ({freq_dict[2,2][index_at_fref]} Hz) before shift is {tf_22_current}s (phase[2,2] = {phase_22_current}).") + + # subtract that from all modes. tf for (2,2) needs to be zero at fref, I will add t_ref to all modes later (create_lisa_injections for injections and precompute for recovery), making tf=t_ref at fref. + for mode in modes: + if debug: + print(f"\tShifting {mode}") + tf_dict[mode] = tf_dict[mode] - time_shift # confirmed that I don't need to set all modes tf as 0. Conceptually, for the same time the other modes will be at a different frequency. + phase_dict[mode] = phase_dict[mode] - 2*np.pi*time_shift*freq_dict[mode] + if mode == (2,2): + phase_22_current = phase_dict[2,2][index_at_fref] + difference = reference_phase - phase_22_current + phase_dict[mode] = phase_dict[mode] + mode[1]/2 * difference + print(f"{mode}, phase = {phase_dict[mode][index_at_fref]}") + if debug: + print(f"tf[2,2] at fref ({fref} Hz) after shift is {tf_dict[2,2][index_at_fref]} (phase[2,2] = {phase_dict[2,2][index_at_fref]}).") + + return tf_dict, freq_dict, amp_dict, phase_dict + + +def get_beta_lamda_psi_terms_Hp(beta, lamda, psi): + """This function gives beta lamda psi terms for each term when we split up n_l * P_lm * n_l in equation 21 of http://arxiv.org/abs/2003.00357. We need this to bring out the psi dependence to marginalize over it. This gives those terms after we transform the plus polarization's frame to SSB frame. + Args: + beta (float): ecliptic latitude, + lamda (float): ecliptic longitude, + psi (numpy.array of shape (n,1): polarization angle array. + Return: + xx, xy, xz, yy, yz, zz terms (each is a numpy array of shape (n,1) + + """ + xx_term = (np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi))**2 + \ + (-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi))*(np.cos(lamda)*np.cos(psi)*np.sin(beta) + np.sin(lamda)*np.sin(psi)) + + xy_term = ((np.cos(psi)*np.sin(beta)*np.sin(lamda)-np.cos(lamda)*np.sin(psi)) * (-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi)) + + (np.cos(psi)*np.sin(lamda)-np.cos(lamda)*np.sin(beta)*np.sin(psi))*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)* np.sin(psi))) + \ + ((-np.cos(psi)*np.sin(beta)*np.sin(lamda)+ np.cos(lamda)*np.sin(psi))*(np.cos(lamda)*np.cos(psi)*np.sin(beta) + np.sin(lamda)*np.sin(psi)) + + (np.cos(psi)*np.sin(lamda)-np.cos(lamda)*np.sin(beta)*np.sin(psi))*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))) + + xz_term = (np.cos(beta)*np.sin(psi)*(np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi)) + + np.cos(beta)*np.cos(psi)*(np.cos(lamda)*np.cos(psi)*np.sin(beta) + np.sin(lamda)*np.sin(psi))) + \ + (np.cos(beta)*np.sin(psi)*(np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi)) - + np.cos(beta)*np.cos(psi)*(-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi))) + + yy_term = (np.cos(psi)*np.sin(beta)*np.sin(lamda) - np.cos(lamda)*np.sin(psi))*(-np.cos(psi)*np.sin(beta)*np.sin(lamda)+ np.cos(lamda)*np.sin(psi)) + \ + (-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))**2 + + yz_term = (np.cos(beta)*np.cos(psi)*(np.cos(psi)*np.sin(beta)*np.sin(lamda) - np.cos(lamda)*np.sin(psi)) + + np.cos(beta)*np.sin(psi)*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))) + \ + (-np.cos(beta)*np.cos(psi)*(-np.cos(psi)*np.sin(beta)*np.sin(lamda) + np.cos(lamda)*np.sin(psi)) + + np.cos(beta)*np.sin(psi)*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))) + + zz_term = (-np.cos(beta)**2 * np.cos(psi)**2 + np.cos(beta)**2 * np.sin(psi)**2) + + combined = np.vstack([xx_term, xy_term, xz_term, yy_term, yz_term, zz_term]) + return combined + +def get_beta_lamda_psi_terms_Hc(beta, lamda, psi): + """This function gives beta lamda psi terms for each term when we split up n_l * P_lm * n_l in equation 21 of http://arxiv.org/abs/2003.00357. We need this to bring out the psi dependence to marginalize over +it. This gives those terms after we transform the cross polarization's frame to SSB frame. + Args: + beta (float): ecliptic latitude, + lamda (float): ecliptic longitude, + psi (numpy.array of shape (n,1): polarization angle array. + Return: + xx, xy, xz, yy, yz, zz terms (each is a numpy array of shape (n,1) + + """ + xx_term = 2*(np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi))*(-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi)) + + xy_term = ((-np.cos(psi)*np.sin(beta)*np.sin(lamda) + np.cos(lamda)*np.sin(psi))*(np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi)) + + (-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi))*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))) + \ + ((-np.cos(psi)*np.sin(beta)*np.sin(lamda) + np.cos(lamda)*np.sin(psi))*(np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi)) + + (-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi))*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))) + + xz_term = (np.cos(beta)*np.cos(psi)*(np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi)) + + np.cos(beta)*np.sin(psi)*(-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi))) + \ + (np.cos(beta)*np.cos(psi)*(np.cos(psi)*np.sin(lamda) - np.cos(lamda)*np.sin(beta)*np.sin(psi)) + + np.cos(beta)*np.sin(psi)*(-np.cos(lamda)*np.cos(psi)*np.sin(beta) - np.sin(lamda)*np.sin(psi))) + + yy_term = 2*(-np.cos(psi)*np.sin(beta)*np.sin(lamda) + np.cos(lamda)*np.sin(psi))*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi)) + + yz_term = (np.cos(beta)*np.sin(psi)*(-np.cos(psi)*np.sin(beta)*np.sin(lamda) + np.cos(lamda)*np.sin(psi)) + + np.cos(beta)*np.cos(psi)*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))) + \ + (np.cos(beta)*np.sin(psi)*(-np.cos(psi)*np.sin(beta)*np.sin(lamda) + np.cos(lamda)*np.sin(psi)) + + np.cos(beta)*np.cos(psi)*(-np.cos(lamda)*np.cos(psi) - np.sin(beta)*np.sin(lamda)*np.sin(psi))) + + zz_term = 2 * np.cos(beta)**2 * np.cos(psi) * np.sin(psi) + + combined = np.vstack([xx_term, xy_term, xz_term, yy_term, yz_term, zz_term]) + return combined + + +def Evaluate_Gslr_test_2(tf, f, beta, lamda): + """This is the main function, takes in tf, f, beta and lamda to generate transfer function for a given mode for each xx, xy, xz, yy, yz and zz term. (need to explain this in paper) + Args: + tf (numpy.array)= -1/2pi d(phase)/df, + f (numpy.array)= frequency array, + beta (float) = ecliptic latitude, + lamda (float) = ecliptic longitude + Returns: + Transfer function L1 (numpy.array with xx, xy, xz, yy, yz, zz), Transfer function L2 (numpy.array with xx, xy, xz, yy, yz, zz), Transfer function L3 (numpy.array with xx, xy, xz, yy, yz, zz) + """ + alpha = omega0*tf + c, s = np.cos(alpha), np.sin(alpha) + k = np.array([-np.cos(beta)*np.cos(lamda), -np.cos(beta)*np.sin(lamda), -np.sin(beta)]) + p0 = np.array([a*c, a*s, np.zeros(len(tf))]) # (3, N) + kR = np.dot(k, p0) # (N,) + phaseRdelay = 2.*np.pi/C_SI *f*kR #(N,) + + p1L =np.array([-a*e*(1 + s*s), a*e*c*s, -a*e*np.sqrt(3)*c]) # (3, N) + p2L =np.array([a*e/2*(np.sqrt(3)*c*s + (1 + s*s)), a*e/2*(-c*s - np.sqrt(3)*(1 + c*c)), -a*e*np.sqrt(3)/2*(np.sqrt(3)*s - c)]) # (3, N) + p3L =np.array([a*e/2*(-np.sqrt(3)*c*s + (1 + s*s)), a*e/2*(-c*s + np.sqrt(3)*(1 + c*c)), -a*e*np.sqrt(3)/2*(-np.sqrt(3)*s - c)]) # (3, N) + + + n1 = np.array([-1./2*c*s, 1./2*(1 + c*c), np.sqrt(3)/2*s]) # (3, N) + kn1= np.dot(k, n1) #(N,) + + + n2 = 1./4. * np.array([c*s - np.sqrt(3)*(1 + s*s), np.sqrt(3)*c*s - (1 + c*c), -np.sqrt(3)*s - 3*c]) + kn2= np.dot(k, n2) + + + n3 = 1./4*np.array([c*s + np.sqrt(3)*(1 + s*s), -np.sqrt(3)*c*s - (1 + c*c), -np.sqrt(3)*s + 3*c]) + kn3= np.dot(k, n3) + + + kp1Lp2L = np.dot(k, (p1L+p2L)) + kp2Lp3L = np.dot(k, (p2L+p3L)) + kp3Lp1L = np.dot(k, (p3L+p1L)) + kp0 = np.dot(k, p0) + + factorcexp0 = np.exp(1j*2.*np.pi*f/C_SI * kp0) + prefactor = np.pi*f*L/C_SI + + factorcexp12 = np.exp(1j*prefactor * (1.+kp1Lp2L/L)) + factorcexp23 = np.exp(1j*prefactor * (1.+kp2Lp3L/L)) + factorcexp31 = np.exp(1j*prefactor * (1.+kp3Lp1L/L)) + + factorsinc12 = np.sinc( (prefactor * (1.-kn3))/np.pi) + factorsinc21 = np.sinc( (prefactor * (1.+kn3))/np.pi) + factorsinc23 = np.sinc( (prefactor * (1.-kn1))/np.pi) + factorsinc32 = np.sinc( (prefactor * (1.+kn1))/np.pi) + factorsinc31 = np.sinc( (prefactor * (1.-kn2))/np.pi) + factorsinc13 = np.sinc( (prefactor * (1.+kn2))/np.pi) + + commonfac = 1j*prefactor*factorcexp0 + G12_term = commonfac * factorcexp12 * np.exp(-1j*phaseRdelay) + G23_term = commonfac * factorcexp23 * np.exp(-1j*phaseRdelay) + G31_term = commonfac * factorcexp31 * np.exp(-1j*phaseRdelay) + + G12xx = G12_term * n3[0,:]*n3[0,:] * factorsinc12 + G21xx = G12_term * n3[0,:]*n3[0,:] * factorsinc21 + G23xx = G23_term * n1[0,:]*n1[0,:] * factorsinc23 + G32xx = G23_term * n1[0,:]*n1[0,:] * factorsinc32 + G31xx = G31_term * n2[0,:]*n2[0,:] * factorsinc31 + G13xx = G31_term * n2[0,:]*n2[0,:] * factorsinc13 + + G12xy = G12_term * n3[0,:]*n3[1,:] * factorsinc12 + G21xy = G12_term * n3[0,:]*n3[1,:] * factorsinc21 + G23xy = G23_term * n1[0,:]*n1[1,:] * factorsinc23 + G32xy = G23_term * n1[0,:]*n1[1,:] * factorsinc32 + G31xy = G31_term * n2[0,:]*n2[1,:] * factorsinc31 + G13xy = G31_term * n2[0,:]*n2[1,:] * factorsinc13 + + G12xz = G12_term * n3[0,:]*n3[2,:] * factorsinc12 + G21xz = G12_term * n3[0,:]*n3[2,:] * factorsinc21 + G23xz = G23_term * n1[0,:]*n1[2,:] * factorsinc23 + G32xz = G23_term * n1[0,:]*n1[2,:] * factorsinc32 + G31xz = G31_term * n2[0,:]*n2[2,:] * factorsinc31 + G13xz = G31_term * n2[0,:]*n2[2,:] * factorsinc13 + + G12yy = G12_term * n3[1,:]*n3[1,:] * factorsinc12 + G21yy = G12_term * n3[1,:]*n3[1,:] * factorsinc21 + G23yy = G23_term * n1[1,:]*n1[1,:] * factorsinc23 + G32yy = G23_term * n1[1,:]*n1[1,:] * factorsinc32 + G31yy = G31_term * n2[1,:]*n2[1,:] * factorsinc31 + G13yy = G31_term * n2[1,:]*n2[1,:] * factorsinc13 + + G12yz = G12_term * n3[1,:]*n3[2,:] * factorsinc12 + G21yz = G12_term * n3[1,:]*n3[2,:] * factorsinc21 + G23yz = G23_term * n1[1,:]*n1[2,:] * factorsinc23 + G32yz = G23_term * n1[1,:]*n1[2,:] * factorsinc32 + G31yz = G31_term * n2[1,:]*n2[2,:] * factorsinc31 + G13yz = G31_term * n2[1,:]*n2[2,:] * factorsinc13 + + G12zz = G12_term * n3[2,:]*n3[2,:] * factorsinc12 + G21zz = G12_term * n3[2,:]*n3[2,:] * factorsinc21 + G23zz = G23_term * n1[2,:]*n1[2,:] * factorsinc23 + G32zz = G23_term * n1[2,:]*n1[2,:] * factorsinc32 + G31zz = G31_term * n2[2,:]*n2[2,:] * factorsinc31 + G13zz = G31_term * n2[2,:]*n2[2,:] * factorsinc13 + + x = np.pi*f*L/C_SI + z = np.exp(1j*2.*x) + + factor_convention = 2 + #factor_convention = 2/np.sqrt(2) # for radler dataset + + factorAE = 1j*np.sqrt(2)*np.sin(2.*x)*z + factorT = 2.*np.sqrt(2)*np.sin(2.*x)*np.sin(x)*np.exp(1j*3.*x) + + Araw_xx = 0.5 * ( (1.+z)*(G31xx + G13xx) - G23xx - z*G32xx - G21xx - z*G12xx ) + Araw_xy = 0.5 * ( (1.+z)*(G31xy + G13xy) - G23xy - z*G32xy - G21xy - z*G12xy ) + Araw_xz = 0.5 * ( (1.+z)*(G31xz + G13xz) - G23xz - z*G32xz - G21xz - z*G12xz ) + Araw_yy = 0.5 * ( (1.+z)*(G31yy + G13yy) - G23yy - z*G32yy - G21yy - z*G12yy ) + Araw_yz = 0.5 * ( (1.+z)*(G31yz + G13yz) - G23yz - z*G32yz - G21yz - z*G12yz ) + Araw_zz = 0.5 * ( (1.+z)*(G31zz + G13zz) - G23zz - z*G32zz - G21zz - z*G12zz ) + + Eraw_xx = 0.5*1/np.sqrt(3) * ( (1.-z)*(G13xx - G31xx) + (2.+z)*(G12xx - G32xx) + (1.+2.*z)*(G21xx - G23xx) ) + Eraw_xy = 0.5*1/np.sqrt(3) * ( (1.-z)*(G13xy - G31xy) + (2.+z)*(G12xy - G32xy) + (1.+2.*z)*(G21xy - G23xy) ) + Eraw_xz = 0.5*1/np.sqrt(3) * ( (1.-z)*(G13xz - G31xz) + (2.+z)*(G12xz - G32xz) + (1.+2.*z)*(G21xz - G23xz) ) + Eraw_yy = 0.5*1/np.sqrt(3) * ( (1.-z)*(G13yy - G31yy) + (2.+z)*(G12yy - G32yy) + (1.+2.*z)*(G21yy - G23yy) ) + Eraw_yz = 0.5*1/np.sqrt(3) * ( (1.-z)*(G13yz - G31yz) + (2.+z)*(G12yz - G32yz) + (1.+2.*z)*(G21yz - G23yz) ) + Eraw_zz = 0.5*1/np.sqrt(3) * ( (1.-z)*(G13zz - G31zz) + (2.+z)*(G12zz - G32zz) + (1.+2.*z)*(G21zz - G23zz) ) + + Traw_xx = 1/np.sqrt(6) * (G21xx - G12xx + G32xx - G23xx + G13xx - G31xx) + Traw_xy = 1/np.sqrt(6) * (G21xy - G12xy + G32xy - G23xy + G13xy - G31xy) + Traw_xz = 1/np.sqrt(6) * (G21xz - G12xz + G32xz - G23xz + G13xz - G31xz) + Traw_yy = 1/np.sqrt(6) * (G21yy - G12yy + G32yy - G23yy + G13yy - G31yy) + Traw_yz = 1/np.sqrt(6) * (G21yz - G12yz + G32yz - G23yz + G13yz - G31yz) + Traw_zz = 1/np.sqrt(6) * (G21zz - G12zz + G32zz - G23zz + G13zz - G31zz) + + AE_term = factor_convention * factorAE + transferL1_xx = AE_term * Araw_xx + transferL1_xy = AE_term * Araw_xy + transferL1_xz = AE_term * Araw_xz + transferL1_yy = AE_term * Araw_yy + transferL1_yz = AE_term * Araw_yz + transferL1_zz = AE_term * Araw_zz + + transferL2_xx = AE_term * Eraw_xx + transferL2_xy = AE_term * Eraw_xy + transferL2_xz = AE_term * Eraw_xz + transferL2_yy = AE_term * Eraw_yy + transferL2_yz = AE_term * Eraw_yz + transferL2_zz = AE_term * Eraw_zz + + T_term = factor_convention * factorT + transferL3_xx = T_term * Traw_xx + transferL3_xy = T_term * Traw_xy + transferL3_xz = T_term * Traw_xz + transferL3_yy = T_term * Traw_yy + transferL3_yz = T_term * Traw_yz + transferL3_zz = T_term * Traw_zz + + return np.array([transferL1_xx, transferL1_xy, transferL1_xz, transferL1_yy, transferL1_yz, transferL1_zz]), np.array([transferL2_xx, transferL2_xy, transferL2_xz, transferL2_yy, transferL2_yz, transferL2_zz]), np.array([transferL3_xx, transferL3_xy, transferL3_xz, transferL3_yy, transferL3_yz, transferL3_zz]) + +########################################################################################### +# FOR INJECTIONS +########################################################################################### +def create_lisa_injections(hlmf, fmax, fref, beta, lamda, psi, inclination, phi_ref, tref, return_response = False): + print(f"create_lisa_injections function has been called with following arguments:\n{locals()}") + tf_dict, f_dict, amp_dict, phase_dict = get_tf_from_phase_dict(hlmf, fmax, fref) + A = 0.0 + E = 0.0 + T = 0.0 + modes = list(hlmf.keys()) + response = {} + mode_TDI = {} + for mode in modes: + H_0 = transformed_Hplus_Hcross(beta, lamda, psi, inclination, -phi_ref, mode[0], mode[1]) + L1, L2, L3 = Evaluate_Gslr(tf_dict[mode] + tref, f_dict[mode], H_0, beta, lamda) + time_shifted_phase = phase_dict[mode] + 2*np.pi*tref*f_dict[mode] + tmp_data = amp_dict[mode] * np.exp(1j*time_shifted_phase) + # I belive BBHx conjugates because the formalism is define for A*exp(-1jphase), but I need to check with ROS and Mike Katz. + A += np.conj(tmp_data * L1) + E += np.conj(tmp_data * L2) + T += np.conj(tmp_data * L3) + response[mode], mode_TDI[mode] = {}, {} + response[mode]["L1"], response[mode]["L2"], response[mode]["L3"] = np.conj(L1), np.conj(L2), np.conj(L3) + mode_TDI[mode]["L1"], mode_TDI[mode]["L2"], mode_TDI[mode]["L3"] = np.conj(tmp_data*L1), np.conj(tmp_data*L2), np.conj(tmp_data*L3) + A_lal, E_lal, T_lal = create_lal_frequency_series(f_dict[modes[0]], A, hlmf[modes[0]].deltaF), create_lal_frequency_series(f_dict[modes[0]], E, hlmf[modes[0]].deltaF), create_lal_frequency_series(f_dict[modes[0]], T, hlmf[modes[0]].deltaF) + data_dict = {} + data_dict["A"], data_dict["E"], data_dict["T"] = A_lal, E_lal, T_lal + if return_response: + return data_dict, response, mode_TDI + else: + return data_dict + +def generate_lisa_TDI(P_inj, lmax=4, modes=None, tref=0.0, fref=None, return_response=False, path_to_NR_hdf5=None): + print(f"generate_lisa_TDI function has been called with following arguments:\n{locals()}") + P = lsu.ChooseWaveformParams() + + P.m1 = P_inj.m1 + P.m2 = P_inj.m2 + P.s1z = P_inj.s1z + P.s2z = P_inj.s2z + P.dist = P_inj.dist + P.fmin = P_inj.fmin + P.fmax = 0.5/P_inj.deltaT + P.deltaF = P_inj.deltaF + P.deltaT = P_inj.deltaT + + + P.phiref = 0.0 + P.inclination = 0.0 + P.psi = 0.0 + P.fref = P_inj.fref + P.tref = 0.0 + + P.approx = P_inj.approx + hlmf = lisa_lalsimutils_compat.hlmoff_for_LISA( + P, Lmax=lmax, modes=modes, path_to_NR_hdf5=path_to_NR_hdf5 + ) + modes = list(hlmf.keys()) + + # create TDI + output = create_lisa_injections(hlmf, P.fmax, fref, P_inj.theta, P_inj.phi, P_inj.psi, P_inj.incl, P_inj.phiref, tref, return_response) + + if return_response: + return output[0], output[1], output[2] + else: + return output + + +def create_h5_files_from_data_dict(data_dict, save_path): + """This function takes in data dictionary and creates h5 files from them. Assumes the data is stores as COMPLEX16FrequencySeries. + Args: + data_dict (dictonary): contains data for A, E, T channels, + save_path (string): path to where you want to save the h5 files. + Output: + None""" + A_h5_file = h5py.File(f'{save_path}/A-fake_strain-1000000-10000.h5', 'w') + A_h5_file.create_dataset('data', data=data_dict["A"].data.data) + A_h5_file.attrs["deltaF"], A_h5_file.attrs["epoch"], A_h5_file.attrs["length"], A_h5_file.attrs["f0"] = data_dict["A"].deltaF, float(data_dict["A"].epoch), data_dict["A"].data.length, data_dict["A"].f0 + A_h5_file.close() + + E_h5_file = h5py.File(f'{save_path}/E-fake_strain-1000000-10000.h5', 'w') + E_h5_file.create_dataset('data', data=data_dict["E"].data.data) + E_h5_file.attrs["deltaF"], E_h5_file.attrs["epoch"], E_h5_file.attrs["length"], E_h5_file.attrs["f0"] = data_dict["E"].deltaF, float(data_dict["E"].epoch), data_dict["E"].data.length, data_dict["E"].f0 + E_h5_file.close() + + T_h5_file = h5py.File(f'{save_path}/T-fake_strain-1000000-10000.h5', 'w') + T_h5_file.create_dataset('data', data=data_dict["T"].data.data) + T_h5_file.attrs["deltaF"], T_h5_file.attrs["epoch"], T_h5_file.attrs["length"], T_h5_file.attrs["f0"] = data_dict["T"].deltaF, float(data_dict["T"].epoch), data_dict["T"].data.length, data_dict["T"].f0 + T_h5_file.close() + + return None diff --git a/MonteCarloMarginalizeCode/Code/RIFT/LISA/response/__init__.py b/MonteCarloMarginalizeCode/Code/RIFT/LISA/response/__init__.py new file mode 100644 index 000000000..5a7346780 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/LISA/response/__init__.py @@ -0,0 +1,2 @@ +"""Time-dependent LISA response utilities.""" + diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_adaptive_driver.md b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_adaptive_driver.md new file mode 100644 index 000000000..d404ae285 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_adaptive_driver.md @@ -0,0 +1,321 @@ +# Adaptive calibration sampling: driver plan + +Status: **planning** (do not implement the multi-stage loop until we pick a path). + +## Zero-cal burn-in of the extrinsic sampler (proposed; ILE-level, analyze_event) + +**Problem.** Calibration marginalization makes the extrinsic integral much harder to +converge: with cal drawn from the broad PRIOR the per-time lnL has a large dynamic range, +so the adaptive sampler (AV) struggles to reach a useful `n_eff` -- and on the FIRST +iteration there is no learned cal proposal yet, so every cold-start ILE job is in this +regime. Empirically (local DAG run) iteration-0 points come out with `sigma ~ 0.7-0.9` +and few effective samples; high-SNR sources are hard even WITHOUT cal, and cal makes it +worse. We are effectively failing to seed the *intrinsic* grid because the *extrinsic* +sampler never gets going. + +**Idea (O'Shaughnessy).** Burn the sampler in on a *different, cheaper* likelihood first +-- the ZERO-CAL (n_cal=1) baseline -- until it reaches a minimal `n_eff`, so the +extrinsic sampling proposal is "roughly right", THEN switch to the full cal-marginalized +likelihood for the production estimate. The extrinsic posterior shape is nearly the same +with and without cal (cal mostly rescales / mildly shifts lnL), so the burned-in proposal +is an excellent warm start -- and the zero-cal evaluations are ~`n_cal`x cheaper. + +**Where it lives.** `analyze_event` in `integrate_likelihood_extrinsic_batchmode`. The +likelihood closures already read the module-scope `n_cal_for_likelihood`; the sampler also +already supports a warm start via `sampler.update_sampling_prior(..., external_rvs=...)` +(the existing `oracleRS` path, ~line 2078). Two viable mechanisms: + + 1. **Two-phase integrate (simplest).** Set `n_cal_for_likelihood = 1`, call + `sampler.integrate(likelihood_function, ..., n_eff=burn_in_neff)` (the closures now + evaluate the fast zero-cal baseline), then restore `n_cal_for_likelihood` and call + the production `sampler.integrate(...)` WITHOUT resetting -- reusing the adapted + proposal. Risk: AV's reset semantics across two integrate() calls in one + analyze_event are unverified (it "always resets every iteration" between DAG + iterations; need to confirm it does NOT reset at the start of integrate()). + 2. **Warm-start via update_sampling_prior (robust).** Run the zero-cal burn-in, + harvest its drawn extrinsic samples + lnL, and feed them to + `sampler.update_sampling_prior(external_rvs=...)` exactly like the oracle path, then + run the production integrate. Survives regardless of integrate()'s reset behavior. + +**Proposed flag.** `--calibration-burn-in-neff ` (0/None = off): target n_eff for +the zero-cal burn-in (capped by a fraction of n-max). Default off; opt-in. + +**Relation to the bigger seeding plan.** This is the in-job version of the same idea as +the *zero-cal pilots* that seed the intrinsic/extrinsic grids for high-SNR sources: burn +in cheaply, then pay for cal only once the sampling is on-target. The pilot (`pilot.py`, +util_CalPilotStage) seeds the CAL proposal across iterations; this burn-in seeds the +EXTRINSIC proposal within a single job. They compose. + +Status: implemented behind `--calibration-burn-in-neff` (two-phase integrate, toggling +`n_cal_for_likelihood`), correctness-safe (the production integral is always the full cal +one). BUT see the sampler limitation below. + +### Sampler limitation (measured / per review) -- BREADCRUMB for future work +The default and most efficient extrinsic sampler, **AV (mcsamplerAdaptiveVolume), +COMPLETELY RESETS between `integrate()` calls** -- there is currently no seedable AV. So +the two-phase burn-in gives AV **no speedup** (the adapted proposal is thrown away); it is +only ever correctness-safe overhead there. Moreover, **re-seeding AV is inherently +dangerous**: AV can only *contract* its volume, never *expand* or *shift* its boundaries, +so a burn-in proposal that is slightly off / too tight would trap the production phase in +the wrong region. The other integrators (**GMM**, **portfolio**) CAN reuse sampling +models (their `update_sampling_prior` / `gmm_dict` hooks) but are less efficient overall. + +Therefore the zero-cal burn-in only pays off once we have: + 1. a **seedable AV** (warm-start AV from external samples / a prior proposal), and/or + 2. a **more flexible, boundary-shifting AV** that can EXPAND and TRANSLATE its sampling + volume (not only contract) -- so a warm start can't trap it. +Both would be broadly useful well beyond calmarg (every iteration, every warm start). +Until then: keep the burn-in flag (harmless, gated, ready) but do NOT rely on it for AV; +for GMM/portfolio it can warm-start via their model-reuse hooks (untested). The cal PILOT +(across-iteration proposal learning) and the prior-shrinkage backstop remain the load- +bearing pieces for cal; extrinsic seeding waits on seedable AV. + +## Where we are + +- In-loop calmarg works and is validated (loop == fused == reference ~1e-14; CPU+GPU; + default/distmarg; phase-marg). See `DESIGN_calmarg_in_loop.md`. +- **Phase 0** (importance weights, `cal_log_weights`) is wired end-to-end: the + marginalizer computes `Z_cal = sum_c exp(log_w_c) integral L_c / sum_c exp(log_w_c)`. +- **Phase 1 core** (`adaptive.py`) is implemented and unit-tested standalone: a tempered + unimodal-Gaussian proposal in cal spline-node space, importance weights + `w_c = prior/proposal`, fit to the cal posterior, `neff` diagnostics. It needs an + `evaluate(nodes) -> log integral_theta L` callback to run for real. + +The open question is **how to supply that callback in the driver without making the run +expensive** — i.e., how to *learn* the cal proposal during a real analysis. + +## Key facts that shape the choice + +1. **The extrinsic integration is slow even at Lmax=2** (toy problem). So anything that + multiplies the number of full integrations (brute force, multi-stage adaptive) is + costly. Per-likelihood-evaluation timing (`backtest_calmarg.py --scan-ncal`, GPU, + 3 IFO, distmarg, 4096 extrinsic samples, ms/eval): + + | n_cal | reference (brute) | loop (Option B) | fused (Option C) | + |------:|------------------:|----------------:|-----------------:| + | 1 | 57.6 | 57.2 | 57.2 | + | 10 | 571 | 266 | 66.9 | + | 50 | 2854 | 1191 | 198 | + | 100 | 5702 | 2347 | 362 | + | 200 | 11427 | 4662 | 704 | + + The marginal cost of one extra cal realization is ~57 ms (brute), ~23 ms (loop), and + ~3.3 ms (fused) -- fused amortizes the cal axis ~18x better than brute force. A + brute-force reference at n_cal=200 is ~11 s **per likelihood evaluation**; a full + integration is thousands of evaluations, i.e. hours-to-infeasible -> reference only. + Fused at n_cal=200 is ~0.7 s/eval -> a production integration is feasible. Net: + **the production path must be fused**, and we still want to keep n_cal modest via a + learned proposal (the pilot below). +2. **Calibration is boring**: the cal posterior is smooth, unimodal, and — crucially — + nearly **independent of the extrinsic parameters** across the high-likelihood region + (it is set by the data + best-fit template, not by sky/inclination/etc). So we do + NOT need to relearn cal per extrinsic sample, nor iterate many times. +3. The calmarg lnL sitting ABOVE the no-cal baseline is **expected physics, not a bug**: + for cal-on-data, at fixed theta lnL_c = lnL_baseline + (delta.h|h) with mean 0, so + Z_cal(theta) = E_C[L] = L_baseline * exp(+ shift); logmeanexp(lnL_c) > mean(lnL_c) + (dominated by the best-fitting cal draws). Confirmed at the injection point on real + data: mean(lnL_c) ~ baseline, logmeanexp(lnL_c) above it by a positive margin. Small + cal variance -> shift ~ 0.5*Var_c[lnL_c]; high SNR -> larger (best-draw dominated). + This is ALSO why neff_cal collapses and adaptive sampling is needed. + +## Options (and recommendation) + +### A. Brute-force reference (DO — as validation, not production) +Marginalize cal "the hard way": draw a large prior cal set and run the full extrinsic +integral, or sample (theta, cal) jointly with no proposal learning. Cleanest and +unambiguous; the **ground truth** to validate B/C and to settle the baseline-vs-calmarg +question. Slow (cost ~ `n_cal` x single integration), so it is a *reference harness*, +not the production path. Implement as a mode that runs the existing integration with a +large prior cal set and high neff, and compare to the lazy/seeded result. + +### B. Expanding scope: portable (extrinsic + cal) distribution / normalizing flow (LONG TERM) +Capture the learned joint (extrinsic + cal) posterior in a **portable object** to pass +downstream — historically a normalizing flow. This is the decade-old "breadcrumbs" +goal; it has failed before partly because it was bolted on per-integrator and never +standardized. The cal framework sits deep in the core and faces the *same* challenge, +so the right move now is **not** to build a full NF, but to define a clean, +integrator-agnostic **breadcrumb interface**: a small object that can `save`/`load` a +learned proposal (start: Gaussian mean/cov over cal nodes + the importance weights; +later: an NF), with a stable schema. Build the hook; defer the NF. + +### C. Lazy pilot (RECOMMENDED first production path) +Because cal is boring and ~extrinsic-independent, learn it ONCE from a cheap pilot: + +1. Get a handful (K ~ tens) of **high-likelihood extrinsic test points** — e.g. the top-K + by lnL from the first ILE iteration / the proposed grid, or just the best-fit point. +2. At those K points, evaluate the per-cal-realization likelihood **fully** (this is K + cheap evaluations, embarrassingly parallel — "spam in parallel"). Average the + responsibilities over the K points (they agree, since cal is extrinsic-independent). +3. Fit the Gaussian proposal (`adaptive.fit_proposal`, tempered) -> seed the cal nodes. +4. Redraw the run's cal realizations from the proposal; set `cal_log_weights = + prior/proposal` (Phase 0). Run the main integration once with the seeded set. +5. (Optional) one refine pass if `neff_cal` is still low. + +This is a single extra pilot (not a multi-stage loop), exploits cal's boringness, reuses +Phase 0 + Phase 1, and degrades gracefully (if the pilot is poor, importance weights +keep it unbiased — just less efficient). + +## AGREED architecture and priority (do all of A, C, B to prep for the future) + +Priority order **A -> C -> B**: +- **A is the critical benchmark** -- the *only* validation. Build first. +- **C is production** (the parallel-pilot DAG below). +- **B is the future** (portable extrinsic+cal distribution / normalizing flow). Lay + breadcrumbs + stub code now so the plan is remembered. + +This is a deliberate "long jump": more structure than calmarg strictly needs, because +the same machinery generalizes to saving the **extrinsic** distribution (the decade-old +goal). Longer path, but richer payoff and easy to exploit later. + +### Source of pilot points: harvest from the previous iteration's `*.composite` +Every RIFT iteration already produces a `*.composite` of evaluated (intrinsic+extrinsic) +points with their lnL -- plenty of trials, no need for a dedicated pilot integration. +The pilot **harvests the top fraction by lnL (~top 5%)** from iteration N-1's composite +and does full cal there. (This same harvest generalizes to learning the extrinsic +proposal.) + +### Parallel-pilot DAG (nothing serial) +Per iteration N, run in parallel: +- **wide_N**: the normal ILE iteration, with `n_cal` modest, its cal realizations SEEDED + from the consolidated proposal produced after iteration N-1 (importance-weighted, + Phase 0). This is the production likelihood. +- **pilot_N**: harvest top-5% lnL points from iteration N-1's composite; do FULL cal at + those points (large prior `n_cal`, embarrassingly parallel -- "spam in parallel"); + emit a breadcrumb (per-point cal responsibilities / a fitted Gaussian). + +Then a **consolidation_N** job (the barrier between N and N+1) collects the pilot +breadcrumbs into a single consolidated cal proposal (Gaussian mean/cov over cal nodes + +importance-weight bookkeeping). **pilot_N informs wide_{N+1}** through that consolidated +proposal. A **cap** limits how many iterations keep pilot jobs active (once cal is +learned -- it is boring -- freeze the proposal and drop the pilots). + +``` + iter N-1.composite ──► pilot_N ──┐ + ├─► consolidation_N ──► wide_{N+1} (seeded) + (wide_N runs in parallel) ──┘ + (pilots run for the first ~K iterations, then frozen) +``` + +### B (breadcrumbs / future): portable distribution object +The consolidated proposal is a **portable save/load object** with a stable, +integrator-agnostic schema. Start: a Gaussian over cal spline nodes (mean, cov) + the +prior + importance-weight metadata. Designed from the start to ALSO carry an extrinsic +proposal (same harvest->fit->consolidate->seed structure). NF is a later drop-in behind +the same interface. Stub the schema + the consolidation/seed hooks now. + +## n_eff is conservative vs the true ESS (refines the starvation math) +RIFT's reported `n_eff` is a deliberately CONSERVATIVE lower bound -- the true effective +sample size (ESS) is meaningfully larger. So `n_eff(us)=100` yields appreciably more +usable fair-draw points than 100. Consequences: +- The earlier "low n_eff" worry was over-pessimistic: with enough samples the integrator + creeps up fine (the tune-condor run reached n_eff>200 on the moderate-SNR injection), and + the usable ESS is larger still. +- Pilot harvesting is LESS starved than the conservative count implied: top-fraction of + the composite + the larger-than-n_eff ESS means a real run can pull out enough + high-quality points to inform the cal proposal after all. The d(d+1)/2 requirement for + a FULL covariance still holds, but the prior-shrinkage backstop covers the residual + unconstrained directions, so we do not need to fully resolve every cal dof to be safe. + +## Build order (this branch) +1. **Timing data** -- done (`--scan-ncal`). +2. **A: brute-force reference** -- prior-only large-`n_cal`, converged; the ground truth. + Testable now in the backtest: brute-force (large prior set) vs adaptive-seeded must + agree on Z_cal while the seeded run has far higher `neff_cal`. +3. **B-lite breadcrumb I/O** -- `save/load` the cal proposal (Gaussian; schema with an + `extrinsic` slot reserved). Used by C. +4. **C core** -- harvest top-fraction from a `*.composite`; fit (adaptive.fit_proposal); + write/consolidate breadcrumbs; seed the next run's cal realizations. +5. **C DAG wiring** (pilot || wide || consolidation, the cap) in the pipeline builder -- + DONE (opt-in; default DAG byte-identical). See "DAG wiring" below. NEEDS a condor + smoke test on a real cluster run (cannot be exercised off-cluster), like the main-path + GPU end-to-end test. + +## DAG wiring (implemented; opt-in via `--calmarg-pilot`) + +A single per-iteration **calpilot** condor job collapses harvest -> dump -> fit -> +consolidate into one process (`bin/util_CalPilotStage.py`), so the pipeline-builder +surgery is minimal and the steps (which are serial anyway) stay in one place: + +``` + iteration N composite ──► CALPILOT_N (util_CalPilotStage.py): + 1. util_CalHarvestGrid.py top-frac high-lnL pts -> cal_pilot_grid_N.xml.gz + 2. ILE --calibration-dump-responsibilities (cheap: skips the extrinsic sampler) + [+ --calibration-proposal-breadcrumb cal_consolidated_{N-1}.npz -> refine] + 3. util_CalPilotFit.py -> cal_proposal_N.npz (auto-tempered) + 4. util_CalConsolidate.py -> cal_consolidated_N.npz + │ + (CALPILOT_N runs ∥ CIP_N/puff_N; parent = unify_N, does NOT gate them) + ▼ + wide ILE jobs of iteration N+1 --calibration-proposal-breadcrumb cal_consolidated_N.npz + (depend on CALPILOT_N; a missing breadcrumb at early N falls back to the prior) +``` + +- `dag_utils.write_calpilot_sub` defines the job; `create_event_parameter_pipeline_BasicIteration` + instantiates `calpilot_node` per active iteration (parent `unify_node`), records it, and + makes iteration N+1's wide ILE nodes depend on `calpilot_node[N]`. ILE nodes carry a new + `macroiterationprev` macro so the per-iteration breadcrumb path resolves. +- **Cap & cadence**: `--calmarg-pilot-max-it` (default 3), `--calmarg-pilot-cadence` + (default 1) -- pilots stop once cal is learned (cal is boring), freezing the proposal. +- `util_RIFT_pseudo_pipe.py`: `--calmarg-pilot[-cadence|-max-it|-top-fraction|-max-points]` + add the CEPP flags and append the `--calibration-proposal-breadcrumb + .../cal_consolidated_$(macroiterationprev).npz` to the wide ILE args (args_ile.txt). + +Run: add `--calmarg-pilot` to a `util_RIFT_pseudo_pipe.py` invocation that already uses +`--calmarg-envelope-directory ...`. Everything is opt-in; without `--calmarg-pilot` the +DAG and ILE behavior are unchanged. + +NOTE (subdag/exploded-ILE): the seed dependency is wired for the standard ILE batch path; +the `--ile-group-subdag` grouped path would need the dependency placed on the subdag node +(left as a follow-up; uncommon for calmarg runs). + +## Implemented executable decomposition (this branch) + +The pilot/seed loop is realized with two ILE flags + two thin CLIs, all opt-in (the +default DAG and likelihood are byte-identical when unused): + +- `generate_realizations.py` (refactored, prior draws byte-identical): + - `build_realizations_from_nodes(...)` -- spline construction, shared by prior & proposal. + - `node_prior(...)` -- the diagonal-Gaussian cal prior per detector. + - `draw_prior_realizations_with_nodes(...)` -- prior draws that KEEP the node vectors + (cold pilot, N=0). + - `seed_realizations_from_breadcrumb(...) -> (factors, cal_log_weights, nodes)` -- draw + cal realizations from a learned proposal + Phase-0 weights log(prior/proposal). +- `factored_likelihood.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop(..., return_cal_components=True)` + returns the RAW per-realization time-integrated log L `(npts_extrinsic, n_cal)` before + the cal collapse (loop method). Validated: `logsumexp_c(components) - log(n_cal)` == + the cal-marg lnL to ~1e-16. +- ILE `--calibration-proposal-breadcrumb `: seed the run's cal realizations from the + proposal (+ thread `cal_log_weights` to all likelihood call sites & resample). This is + what **wide_{N+1}** uses. +- ILE `--calibration-dump-responsibilities ` (+ `--calibration-pilot-extrinsic`): + the **pilot**. Keeps the cal node draws; at each analyzed (harvested) intrinsic point, + evaluates `return_cal_components` over a uniform-prior extrinsic batch and accumulates + `int dOmega L_c` per realization; writes `(nodes, log_resp=log_w+log int L_c, prior...)`. + If `--calibration-proposal-breadcrumb` is ALSO given, the pilot draws FROM that proposal + (refinement: pilot_N seeded by consolidation_{N-1}) and folds `log_w` into `log_resp`. +- `bin/util_CalPilotFit.py`: pool dumps -> `adaptive.fit_proposal` (AUTO-TEMPERED: pick the + largest beta<=1 whose tempered neff >= `target_neff_frac*n_cal`, so a low-neff cold draw + cannot collapse the proposal) -> breadcrumb. +- `bin/util_CalConsolidate.py`: precision-weighted combine of pilot breadcrumbs (or a + single-input pass-through) -> the consolidated proposal that seeds wide_{N+1}. + +The across-DAG-iteration loop (pilot_N seeded by consolidation_{N-1}, refit, ...) is +exactly `adaptive.adaptive_cal` UNROLLED over RIFT iterations -- no extra serial cost. + +## Convergence characterization (measured) + +The cal node space is high-dimensional (2 * spline_count * n_det; e.g. 60 for 10 nodes x +3 IFOs). A single Gaussian proposal learned from one prior shot in this space converges +SLOWLY when the cal posterior is strongly displaced/narrowed vs the prior: in a stress +test (12 of 60 nodes offset 1 sigma, tightened to 0.5 sigma) the responsibility neff sits +~1-3 and `|mean-true|` only falls to ~0.5 sigma over many rounds -- and the reference +`adaptive.adaptive_cal` behaves the SAME (this is intrinsic to broad-prior importance +sampling in high-D, not a wiring defect). Two things make this acceptable: +1. **Correctness is independent of pilot quality.** The Phase-0 importance weights make + the marginalization UNBIASED for any proposal; a poor pilot only lowers `neff_cal`. +2. **Real cal is boring.** Posteriors are small, smooth, near-prior displacements; in a + benign regime (offset ~0.3 sigma) the prior is already a decent proposal and the pilot + gives a modest neff gain. The big wins are when cal is genuinely informative, where + the across-iteration climb accumulates. +For a sharp high-D posterior the right long-term tool is **B (normalizing flow)** behind +the same breadcrumb interface -- a single Gaussian is the deliberate first cut. diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_calmarg_in_loop.md b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_calmarg_in_loop.md new file mode 100644 index 000000000..3f531a04a --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_calmarg_in_loop.md @@ -0,0 +1,307 @@ +# In-loop calibration marginalization in RIFT ILE + +Branch: `rift_O4d_junior_calmarg_in_loop` (off `rift_O4d_junior_distance`) + +## Motivation + +RIFT currently marginalizes over calibration uncertainty in **postprocessing** +(`bin/calibration_reweighting.py`, bilby-based): after `extrinsic_posterior_samples.dat` +is produced, each sample is reweighted against a set of random calibration draws. +The extrinsic samples entering this step are *not* informed by calibration, so for +high-SNR sources and/or broad calibration priors the reweighting is very inefficient +(most proposed samples get tiny weights). + +Modern GPUs are heavily under-utilized by RIFT's inner loop, so we move the +calibration marginalization **inside ILE**, marginalizing over calibration draws +on-board while the extrinsic likelihood is being evaluated. + +## Key idea: apply calibration to the *data* + +The factored likelihood evaluated by +`factored_likelihood.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop` +combines two quantities: + +* `kappa_sq` — the **data-template** term, built from the GPU Q-product over the + precomputed rholm timeseries `rholmsArrayDict[det] = (t)`. This is the + **only** data-dependent quantity. +* `rho_sq` — the **template-template** cross terms `U,V` (`ctUArrayDict`, + `ctVArrayDict`), ``. These depend on the template and PSD but **not** + on the data. + +If calibration `C(f)` is applied to the **data** (`d -> C(f)·d`), then `rho_sq` is +**calibration-independent** and is computed once; only `kappa_sq` changes per +realization. This is what makes in-loop marginalization cheap. + +> **Convention note for review.** Bilby's `GravitationalWaveTransient` applies the +> calibration factor to the *template/response*, which also rescales the `` +> norm. Applying to the data (our choice) and applying to the template agree to +> first order in the calibration amplitude but differ at second order. The +> apply-to-data choice is what preserves the efficiency win (shared `U,V`). The +> backtest below quantifies the difference against `calibration_reweighting.py`. + +## Data layout + +`RIFT/calmarg/generate_realizations.py::create_realizations` draws `n_cal` complex, +two-sided calibration factors on the full FFT frequency grid (matching +`lalsimutils` packing) from a bilby envelope `.txt` file, shape `(npts_seg, n_cal)`. +Column `c` across detectors is one **joint** draw. + +`ComputeModeIPTimeSeries` (cal branch) applies realization `c` to the data and +concatenates the resulting windowed rholm into one timeseries: + +``` +rholm[det] = [ block_0 | block_1 | ... | block_{n_cal-1} ] length = N_window * n_cal +``` + +`PackLikelihoodDataStructuresAsArrays` carries this long array through unchanged, +so `rholmsArrayDict[det]` has shape `(n_lms, N_window * n_cal)`. Realization `c` is +selected simply by shifting the per-sample window offset: + +``` +ifirst_c = ifirst + c * N_window +``` + +## Marginalization (implemented: Option B) + +We Monte-Carlo marginalize over the `n_cal` draws: + +``` +Z_cal(theta) = (1/n_cal) * sum_c integral dt exp( lnL_t(theta, c) ) +``` + +`DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop` gains an `n_cal` argument. +With `n_cal == 1` the code path is byte-for-byte the original. With `n_cal > 1`: + +1. `rho_sq` is accumulated once in the detector loop (calibration-independent). +2. The per-detector Q-product inputs (`Q`, `FY_conj`, `ifirst`, `N_window`) are cached. +3. For each realization `c`, `kappa` is recomputed via the **existing** GPU kernel + (`Q_inner_product.Q_inner_product_cupy`) with `ifirst + c*N_window`, combined with + the shared `rho_sq` through the same `loglikelihood` callback (distance/phase marg), + and accumulated with a **streaming log-sum-exp** for numerical stability. +4. Finish with `simps` over time and `- log(n_cal)`. + +**Why Option B (cal loop) over the alternatives:** + +| Option | Idea | Memory | Kernel | Review cost | +|---|---|---|---|---| +| A | replicate extrinsic batch ×n_cal, one kernel call | ×n_cal (forces smaller batch) | reused verbatim | lowest LOC | +| **B (chosen)** | loop realizations, reuse kernel, stream log-sum-exp | **unchanged** | reused verbatim, n_cal launches | low | +| C | fused CUDA kernel: Q + loglikelihood + cal-LSE on-board | minimal | new kernel | highest | + +Option B is memory-neutral and reuses the validated kernel — the right +minimum-violence first step given GPUs have spare throughput. + +### Option C (implemented for the default helper) + +`cal_method='fused'` runs a single fused CUDA kernel +(`RIFT/likelihood/cuda_Q_fused_calmarg.cu`, wrapped by +`RIFT/likelihood/Q_fused_calmarg.py`). One thread per extrinsic sample loops over +realizations × time × detectors × modes, forms the data term `kappa`, applies the +default factored-likelihood helper `lnL_t = invDist*Re(kappa) - 0.5*rho_sq`, and +accumulates a streaming, Simpson-weighted log-sum-exp over `(c,t)` — returning +`lnL[j]` directly. No `(batch, n_cal, npts)` intermediate, no per-realization +Python launches. + +Time integration matches Option B exactly by passing the composite-Simpson weight +vector `w_t = simps(I, dx=deltaT)` (simps is linear, so its action is a fixed weight +vector) into the kernel. `rho_sq` is calibration-independent and passed in +pre-summed over detectors. + +**Validated** in the harness vs the brute-force reference and Option B to ~1e-15 on +GPU, single- and multi-detector (H1,L1,V1 — exercises the kernel's detector loop and +the per-detector ifirst stacking). Throughput (NVS 510, sm_30; single synthetic +detector): + +| case | reference | Option B | Option C | +|---|---|---|---| +| n_cal=100, 1024 samples | 695 ms | 170 ms | **22 ms** | +| n_cal=200, 8192 samples | 7080 ms | 2422 ms | **279 ms** | + +i.e. ~8–9× over Option B and ~25–32× over brute force, with bit-level agreement. + +### Option C, stage 2 — distance marginalization (implemented, separate kernel) + +The dominant production path uses the distance-marginalization `loglikelihood` +(sites 1828/1871). This is implemented as a **separate** kernel +(`RIFT/likelihood/cuda_Q_fused_calmarg_distmarg.cu`, wrapper +`Q_fused_calmarg_distmarg_cupy`), kept apart from the default-helper kernel on +purpose: it keeps each kernel's review surface small, leaves the simpler kernel as a +baseline, and leaves `cal_method='loop'` (Option B) as a full fallback for distmarg +on both CPU and GPU. + +It reproduces `distmarg_loglikelihood` exactly on-board: +`x0 = kappa/rho_sq`; `s = asinh(√bmax·(x0−xmin)) − asinh(√bmax·(xmax−x0))`; +`t = asinh(rho_sq/bref)`; bilinear interpolation of `lnI_array` at `(s,t)` (matching +`EvenBivariateLinearInterpolator`, with the same in-bounds mask, contributing 0 +otherwise); plus `exponent_max`. Selected via `cal_method='fused'` **and** passing a +`cal_distmarg` table dict (`lnI_array`, `s0/ds/smin/smax`, `t0/dt/tmax`, +`xmin/xmax/sqrt_bmax/bref`); with `cal_distmarg=None` the default-helper kernel is +used. + +**Validated** in the harness (`--loglikelihood distmarg`, which builds a +self-consistent table and the mirror Python closure for reference/Option B) to +~1e-14 vs the brute-force reference, single- and multi-detector (the asinh/bilinear +differ from numpy only at ULP level). Throughput (sm_30): + +| case (distmarg) | reference | Option B | Option C | +|---|---|---|---| +| n_cal=100, 1024 samp, 2 det | 1364 ms | 495 ms | **77 ms** | +| n_cal=200, 2048 samp, 3 det | 6136 ms | 2358 ms | **333 ms** | + +i.e. ~6–7× over Option B. + +**Scope / limitations of both fused kernels** (raise `NotImplementedError` +otherwise): GPU only; `phase_marginalization=False`; all detectors share +modes/length (true after global mode pruning). + +### Driver wiring (opt-in) + +`integrate_likelihood_extrinsic_batchmode` exposes the fused path behind +`--calibration-fused-kernel` (off by default). When set (and on GPU, with +calibration marginalization active), the driver packages the distance-marginalization +`lookup_table` (`s_array`, `t_array`, `lnI_array`, `bmax`, `bref`) plus `xmin/xmax` +into a `cal_distmarg` dict and passes `cal_method='fused'` at the **non-phase-marg** +distmarg call site. The phase-marg distmarg site and everything else stay on +`cal_method='loop'` (Option B), which remains the default and the fallback for all +cases. On CPU the flag is ignored with a warning (the kernel is GPU-only). + +**End-to-end status.** Run through `integrate_likelihood_extrinsic_batchmode` on the +CI fake data with `--distance-marginalization` + a real `util_InitMargTable` table + +`--calibration-envelope-directory --calibration-fused-kernel`. This caught a real +wiring bug — in the distmarg path `P.dist` is fixed at the fiducial, so `invDistMpc` +is a scalar, but the fused kernel wants one value per extrinsic sample; the fused +branch now broadcasts it to `(npts_extrinsic,)`. After the fix the fused path runs to +completion. Numerics were validated deterministically with +`backtest_calmarg.py --loglikelihood distmarg --real-table `: fused == reference +== loop to ~2e-14 on the production table. (A full *sampler* end-to-end numerical +comparison needs a larger GPU than the local 2 GB card, which OOMs / returns nan under +load.) + +Remaining: a full numerical end-to-end on a larger GPU; phase-marginalization support +in the fused kernels (then the phase-marg distmarg site can opt in too). + +## Driver wiring + +`bin/integrate_likelihood_extrinsic_batchmode` already had the scaffolding: + +* options `--calibration-envelope-directory`, `--calibration-n-realizations`, + `--calibration-spline-count`; +* builds `calibration_realization_dict` and passes it into `PrecomputeLikelihoodTerms`. + +This branch adds `n_cal_for_likelihood` (= `--calibration-n-realizations` when +calibration marginalization is active, else 1) and threads it into the three +production `DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop` call sites +(plain / distance-marg / distance+phase-marg). + +## Bug fixed + +In `ComputeModeIPTimeSeries`'s calibration branch the inner product was being taken +against the *original* `data` instead of the calibration-modified `data_now` +(`IP.ip(hlms[pair], data)` → `IP.ip(hlms[pair], data_now)`), so the calibration +factor was previously never applied. + +## Validation + +`RIFT/calmarg/test_calmarg_reduction.py` builds a synthetic 2-mode, single-detector +case and checks the `n_cal>1` result against a brute-force reference — running the +unchanged `n_cal==1` path on each realization block separately and combining by hand +(`logsumexp_c(lnL_c) - log n_cal`). Agreement is machine precision (~1e-15) on both +the CPU (`xpy=np`) and the production GPU (`xpy=cupy`, real `Q_inner_product` kernel) +paths. It also confirms the `n_cal==1` path is a regression-identical block-0 eval. + +## Backtest harness + +`RIFT/calmarg/backtest_calmarg.py` is the rig **Option C is developed against**. It +holds a `METHODS` registry — `reference` (brute-force per-block + logsumexp), +`in_loop_B` (the `n_cal>1` call), and `in_loop_C` (a stub raising `NotImplementedError` +until the fused kernel exists) — and evaluates each over synthetic inputs that +exercise the cal-block structure, reporting `max|lnL - reference|` and best-of-N +timing on CPU (`--backend cpu`) or GPU (`--backend gpu`). Wire the fused kernel into +`method_in_loop_C` and the harness validates it automatically. + +``` +python -m RIFT.calmarg.backtest_calmarg --backend gpu --n-cal 100 --npts-extrinsic 4096 --repeat 5 +``` + +Current status: `in_loop_B` reproduces `reference` to ~1e-15 on CPU and GPU, with and +without phase marginalization; on GPU it is ~3–4× faster than the brute-force +reference (which redundantly recomputes `rho_sq` per realization — exactly the +redundancy Option C removes). + +`run_physics_backtest()` in the same module is the **scaffold** (docstring + TODOs) for +the heavier real-data comparison vs bilby `calibration_reweighting.py`: load a real +ILE precompute + cal envelopes + the bilby data_dump, evaluate the in-loop calmarg +likelihood on the same extrinsic samples, and compare per-sample lnL and the +log-evidence shift. It needs frames/PSDs, so it runs on the stable host (not in CI). + +## Calibration MC error budget (implemented) + +The sampler's reported variance is the *extrinsic* sampling variance with the cal +draw set held fixed — it is structurally blind to the Monte-Carlo error of the +`(1/n_cal) sum_c` average. Empirically (demo `pp-run`, wide envelopes, `NCAL_DAG=20`) +this produced a 2d lnL surface with ~1.0 point-to-point noise quoted at sigma~0.18 +(chi^2/dof ~ 34 against a smooth surface fit). + +`adaptive.cal_mc_error_from_components(comp, cal_log_weights)` computes, from the +per-realization components on a modest extrinsic-prior batch (`return_cal_components`, +responsibilities are ~extrinsic-independent — same trick as the pilot): + +* `a_c = w_c Z_c / (n_cal Z)` — normalized per-draw contributions (sum to 1); +* `Var(lnZ) ~= n_cal * Var_c(a_c)` (delta method; reproduces the lognormal + `(e^{sigma^2}-1)/n_cal`, validated in `test_cal_mc_error.py`); +* `neff_cal = 1/sum a_c^2` — when `< 10` the estimate is a LOWER BOUND and the + point is flagged in the log. + +The driver folds this in quadrature into the reported sigma column and prints +`[calmarg error] sigma_lnZ: extrinsic X (+) cal Y -> total Z ; cal n_eff ...`. +The probe (`_cal_error_probe`) uses an ADAPTIVE extrinsic batch (doubling until the +estimate stabilizes, capped by `--calibration-mc-error-extrinsic`, default 8192, +0 disables) and draws distance from the RUN'S distance prior: the sampler's own +`prior_pdf['distance']` when distance is sampled (uniform proposal + importance +weight, so the cosmo/redshift variants are handled by construction), the `--d-prior` +pdf when distance marginalization is active, or the PINNED value (warned: at fixed +distance the distance/amplitude degeneracy cannot absorb amplitude-like cal +perturbations, so the estimate is conservative). + +**Adaptive draw count** (`--calibration-neff-cal-target`, default 10; +`--calibration-n-realizations-max`, default 8x initial): after the cal-block +precompute, the same probe measures `neff_cal` at this intrinsic point; while below +target the draw set is DOUBLED — fresh independent draws appended via +`_draw_more_calibration_draws` (extends the realization dict, importance weights, +and node bookkeeping in place), with an incremental `PrecomputeLikelihoodTerms` of +only the new blocks concatenated onto the packed rholm arrays. So +`--calibration-n-realizations` is a *starting* size, not a trusted constant. +`[calmarg adapt]` log lines record the escalation. + +**Sizing guidance** (toy-model scaling, see the paper repo +`demos/calmarg/cal_envelope_scaling.py`): per-draw spread `sigma_lnL ~ rho^2 eps_A` +(amplitude-envelope dominated, ~1.0 per 1% amplitude at network SNR 20), and +`n_cal ~ (e^{sigma_lnL^2}-1)/sigma_target^2`. GWTC-4-scale envelopes (<~2% / <~2 deg) +need `n_cal ~ 100-1000` at SNR 20: **start at 100 and let the adaptive escalation +work; 300 is a comfortable fixed choice**. Beyond ~3% amplitude (or proportionally +higher SNR) prior draws are hopeless; the learned-proposal machinery (pilot / +breadcrumbs) targets that regime but is EXPERIMENTAL — it must be validated against +the brute-force path before being relied on, and is deliberately kept out of the +active/default paths. Memory: realization blocks add ~0.3 MB/draw GPU-resident in +the demo config (88 MB at n_cal=300); per-eval cost is linear in n_cal (fused +kernel: ~0.25 s per 1000-sample chunk at n_cal=300 extrapolating the sm_30 timings +above). + +## Open items / future work + +* **Option C** fused kernel for maximum throughput; backtest vs Option B and vs the + bilby postprocessor on a high-SNR / broad-prior event. +* **Reproducibility:** `create_realizations` uses unseeded `np.random`; add a + `--calibration-seed` so a run's draw set is reproducible. DECISION (2026-06): + workers must KEEP drawing independent sets — common random numbers across + intrinsic points were considered and rejected (a shared draw set makes the lnL + surface artificially smooth and bakes its O(1/sqrt(n_eff_cal)) bias into the + posterior); the variance is instead disclosed via the cal MC error budget above + and beaten down with larger n_cal (now grown adaptively per point). +* **Calibration-parameter export:** Option B does not record which realization was + selected (acceptable per scope — parameter draws can be regenerated at the end as + the current `--dump_cal_realization` path does). +* **Grid sanity asserts:** verify `len(realizations) == data.length` and + `N_window*n_cal == rholm length` explicitly at setup time. +* **CPU + phase-marg + calmarg** uses an explicit einsum mirroring the kernel; covered + by the test for the non-phase-marg case. diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_extrinsic_handoff.md b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_extrinsic_handoff.md new file mode 100644 index 000000000..8fd4f1b04 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/DESIGN_extrinsic_handoff.md @@ -0,0 +1,285 @@ +# Extrinsic handoff: carry the extrinsic posterior between iterations + +Status: **proof-of-concept implemented (GMM)**; AV partial-reset is future work (task #30). + +## The decade-old goal + +RIFT re-solves the *same* extrinsic integral (sky, distance, inclination, polarization, +orbital phase, time) on every intrinsic-grid point, every iteration. But the extrinsic +posterior is set by the data + best-fit template, not by the small intrinsic-grid moves -- +so it barely changes from iteration to iteration. Today each ILE job starts its extrinsic +sampler **cold** (a wide prior proposal) and re-discovers the same sky modes / distance +blob from scratch. This is the long-standing "save the extrinsic distribution to inform +the next iteration" idea: learn the extrinsic proposal once, hand it forward, and let the +next run start *on the answer*. + +This generalizes the calibration pilot's breadcrumb (`RIFT.calmarg.breadcrumbs`, +`RIFT.calmarg.generate_realizations.seed_realizations_from_breadcrumb`): the cal handoff +carries a Gaussian over spline-node parameters; the extrinsic handoff carries a learned +proposal over the extrinsic parameters in the SAME breadcrumb file. + +## GMM-first (implemented) + +RIFT's ensemble sampler (`mcsamplerEnsemble`) is **already seedable**. Its `gmm_dict` +maps parameter GROUPS -- tuples of indices into `params_ordered` -- to a fitted +`gaussian_mixture_model.gmm`; a non-`None` entry is used as the starting proposal and keeps +adapting. The standard groups (set in `analyze_event` ~line 1410) are: + + (right_ascension, declination) # sky + (distance, inclination) # distance/orientation + (phi_orb, psi) # phase/polarization + +So the GMM is the trivially-seedable model to prove the handoff with -- no new sampler +machinery, just pre-fill `gmm_dict`. A normalizing flow (or a seedable AV, below) can drop +in later behind the same fit/seed interface (`extrinsic['kind'] != 'gmm'`). + +### Pieces + +- **`RIFT.calmarg.extrinsic_handoff`** + - `fit_extrinsic_proposal(samples, log_weights, groups=STANDARD_GROUPS, bounds, n_comp=4)` + -- per group, fit RIFT's OWN `gmm.fit` (the exact fitter the sampler uses in + `update_sampling_prior`), using importance weights `lnL + ln prior - ln sampling_prior`. + Returns the portable breadcrumb `extrinsic` dict (per-group means/covs/weights/bounds + + parameter NAMES). GMMs may run on cupy -- inputs are moved via + `model.identity_convert_togpu` before `.fit`. + - `reconstruct_gmm(group, adapt=True)` -- rebuild a `gmm` from a stored group (means/ + covariances/weights restored in the model's internal normalized frame; `adapt=True` + lets the seeded components keep adapting, since the extrinsics drift slightly). + - `gmm_dict_from_breadcrumb(extrinsic, params_ordered, adapt=True)` -- build the + `{dim_group_tuple: gmm}` to seed the next sampler. Dim-groups are looked up by + parameter NAME against this run's `params_ordered`, so the handoff is robust to a + different parameter ordering between runs; groups whose params aren't all present are + skipped silently. + + Using RIFT's own fitter means the stored means/covariances are in exactly the model's + internal (normalized) frame and restore to a byte-identical model -- no coordinate + guesswork. + +- **`RIFT.calmarg.breadcrumbs` (schema v2)** -- the `save`/`load` object gained an + `extrinsic` slot alongside the existing `cal` slot. A breadcrumb can carry cal, extrinsic, + or both. Per group it stores `params`/`means`/`covariances`/`weights`/`bounds`. Schema + is additive (v1 cal-only breadcrumbs still load); bump `SCHEMA_VERSION` on incompatible + changes. + +- **ILE wiring** (`integrate_likelihood_extrinsic_batchmode`, execute-point -- needs a + container rebuild): + - `--extrinsic-proposal-output PATH` -- after `sampler.integrate`, harvest the run's + extrinsic posterior samples + importance weights from `sampler._rvs` (same weight recipe + as the distance-grid export, including the GMM sampler's raw-integrand storage), fit per + group, and `breadcrumbs.save(PATH, extrinsic=...)`. Wrapped in try/except so a + harvest/fit failure can never break a production integration. + - `--extrinsic-proposal-breadcrumb PATH` -- before integration, load the breadcrumb and + pre-fill `gmm_dict` for the matched dim-groups (`gmm_adapt=True`). Missing/unreadable + breadcrumb -> warn and fall back to the cold default. + +### Proof of concept + +`python -m RIFT.calmarg.extrinsic_handoff` builds a synthetic **bimodal** sky posterior + +unimodal distance/inclination blob, fits it, round-trips through a breadcrumb, seeds a fresh +GMM against a *shuffled* `params_ordered`, and confirms the seeded sky GMM reproduces BOTH +sky modes with ~the right mode fractions. `python -m RIFT.calmarg.breadcrumbs` confirms the +cal-Gaussian + extrinsic-GMM coexist and round-trip. Both PASS. + +## Pipeline wiring (implemented) + +The handoff is wired end-to-end through the pipeline, gated by `--extrinsic-handoff` and +**standalone** (it does NOT require the cal pilot -- it works on a plain fused / vanilla run): + +- **`util_RIFT_pseudo_pipe.py --extrinsic-handoff`** adds to `args_ile.txt`: + - `--extrinsic-proposal-output extr_proposal_$(macroiteration)_$(macroevent).npz` -- each + wide ILE job writes its own per-event proposal ($(macroevent) is the per-node macro); + - `--extrinsic-proposal-breadcrumb .../extr_consolidated_$(macroiterationprev).npz` -- the + seed from the previous iteration (OSG: basename + auto-added to the ILE transfer list + + an `extr_consolidated_-1.npz` placeholder for iteration 0; shared FS: absolute path). + It warns if `--ile-sampler-method` is not GMM (the seed is a no-op for other samplers). + +- **`util_ExtrinsicConsolidate.py`** (new) picks the single most representative per-event + proposal (default by lnL -- nearest the peak; `--select neff|n_samples` also available) and + writes `extr_consolidated_.npz`. It ALWAYS writes output (empty if nothing valid), so + the next iteration's seed/transfer never fails; unreadable/placeholder inputs are skipped. + +- **`dag_utils_generic.write_extrconsolidate_sub`** builds the consolidation job in the + **local universe** on the submit node: it is pure-python file selection (no GPU/ILE/ + container/frames), and on OSG the per-event ILE outputs are transferred back to + `/iteration__ile` (ILE's default output transfer), so a local-universe job reads + them from the shared FS with no per-event input transfer (which condor cannot glob). + +- **`create_event_parameter_pipeline_BasicIteration`** creates one consolidation node per + iteration, gated behind that iteration's `unify` node (all ILE done -> per-event proposals + present), and makes iteration N+1's wide ILE jobs depend on the iteration-N consolidation: + unify_{it} -> EXTRCONSOLIDATE_{it} -> wide ILE_{it+1} + (the consolidate barrier and the seed barrier), exactly mirroring the cal-pilot wiring. + +`make extr-build` (demo/rift/calmarg) builds a pipeline with `--extrinsic-handoff +--ile-sampler-method GMM` and validates the whole thread offline (args_ile.txt flags, +EXTRCONSOLIDATE.sub, and the unify->consolidate->next-ILE DAG edges). + +Because cal and extrinsic live in ONE breadcrumb object, a future refinement could ride the +extrinsic proposal on the cal pilot's existing consolidation/transfer instead of a separate +node; the standalone path was chosen first so the handoff works without the (heavier) cal +pilot. The convergence-subdag extension (`--first-iteration-jumpstart`) does not yet carry +`--extrinsic-handoff` -- same limitation as `--calmarg-pilot`. + +## Real-GPU validation (cardassia, NVS 510) and what it taught us + +Ran the full loop interactively on one intrinsic point, GMM sampler + calmarg-fused, on the +CI data: iteration-0 writes `extr_proposal_0_0.npz` -> `util_ExtrinsicConsolidate` picks it +-> iteration-1 ILE loads it and prints `Extrinsic GMM SEEDED ... for dim-groups +[(4,5),(3,2),(0,1)]` (all three standard groups) -> integrates -> writes +`extr_proposal_1_0.npz`. End-to-end the plumbing works on real hardware. Two bugs only the +GPU run surfaced, now fixed: + +1. **bounds left on the host.** `reconstruct_gmm` set means/covs/weights onto the GPU but + left `self.bounds` as numpy. The sampler's `score()`/`_normalize` write into an + `xpy.empty` (cupy) array, so a numpy `self.bounds` raised + `ValueError: non-scalar numpy.ndarray cannot be used for fill`. Fix: `model.bounds = + identity_convert_togpu(bounds)`. +2. **within-group parameter ORDER.** The sampler keys the phase/pol group as + `(psi, phi_orb)=(0,1)` but the breadcrumb stored `(phi_orb, psi)=(1,0)`, so that seed was + silently dropped (key mismatch). Fix: `gmm_dict_from_breadcrumb(existing_keys=...)` matches + each breadcrumb group to the sampler's actual gmm_dict key by dim-SET and permutes the + stored means/covariances/bounds columns into that key's order. + +**Seed quality depends on the SOURCE iteration's convergence.** When the ensemble sampler +hits a bad batch it calls `_reset()`, which sets every `gmm_dict[k]=None` -- i.e. it +**discards the seed and continues cold**. This is the correct safety net: a bad seed is +thrown away, never corrupting the result. In a deliberately tiny smoke (`--n-max 40000` on +the NVS 510 -> iteration-0 `n_eff ~ 1`), the iteration-0 proposal is near-degenerate, so the +seeded first batch produces zero/NaN effective weights and the sampler resets to cold. The +handoff is then correct-but-cosmetic. To see the seed actually ACCELERATE convergence you +need a source iteration that converged reasonably (`n_eff` in the hundreds) -- i.e. a real +`--n-max` (millions) and/or a larger GPU. A modest `cov_inflate` (default 2.0, ~1.4x width) +broadens the seed so the sampler can contract it -- good practice for a warm start, but it +mitigates rather than rescues a genuinely degenerate source. + +## Measured blocker: the GMM sampler does not converge on real sharp ILE peaks + +Trying to demonstrate the seed ACCELERATING convergence on the CI point (network SNR ~17.5, +lnLmax ~ 90-115) surfaced a hard limit of the *seedable* sampler itself, independent of the +handoff and of calibration: + +| config (single CI point, GMM sampler) | n_eff at ~200k samples | +|----------------------------------------------|------------------------| +| GMM + calmarg (n_cal=20) | ~1.0 (256k) | +| GMM, vanilla (no calmarg) | 1.00007 (196k, 50 it) | + +The ensemble (GMM) sampler collapses its mixture onto the single dominant sample at a sharp, +high-SNR peak and then stops improving -- n_eff is pinned at 1 with or without calmarg. +(The AV sampler, by contrast, reached n_eff in the hundreds at a few x10^6 samples in the +earlier calmarg tune runs -- AV's adaptive tessellation handles these peaks; GMM does not.) + +Consequence for the handoff: the GMM->GMM extrinsic handoff is correct and safe, but on real +high-SNR ILE likelihoods the GMM SOURCE iteration never converges to a good proposal, so there +is nothing useful to hand off, and the cold GMM baseline is equally stuck -- there is no +acceleration to measure. The handoff's value is therefore gated on a *seedable sampler that +actually converges*: + - **seedable / partial-reset AV (task #30, #25)** -- the real unlock: AV converges on these + peaks but resets every integrate() and has no seed path. This is now the critical-path + item for making the extrinsic handoff pay off on production data. + - or a **cross-sampler handoff**: converge with AV, fit the GMM to AV's posterior samples + (fit_extrinsic_proposal already does exactly this from any sampler's weighted samples), + and seed a GMM/flow refinement. The save side already accepts arbitrary samples+weights; + only the "harvest AV's _rvs and fit" wiring would be new. + +The handoff plumbing (save -> consolidate -> seed, all groups, GPU-correct) is done and is the +right substrate; the demonstration of speed-up waits on one of the above. + +## Seed adaptation: FREEZE by default (`--extrinsic-proposal-adapt`) + +Re-fitting a seeded GMM group on the first batch is fragile on these likelihoods: with +`adapt=True` the sampler's `_train` calls the GMM fit, whose `_initialize` does +`random.choice(p=weights)` and dies on the pathological first-batch weights +("probabilities are not non-negative") -> `_reset()` -> the seed is discarded. `_train` +already skips groups whose `gmm_adapt[group]` is False, so the ILE seed path now FREEZES the +seeded groups by default (`gmm_adapt=False`); `--extrinsic-proposal-adapt` opts back into +adaptation. Freezing is also the right semantics for a handed-off (especially cross-sampler) +proposal: trust it as-is rather than let GMM's adaptation degrade it. Result: with freeze the +seeded run completes with **0 resets** and the seed actually drives sampling. + +## Cross-sampler AV->GMM seed: partial result, integral still wrong (open) + +Per the chosen plan, converged iteration-0 with **AV** (which does make progress on this +point: n_eff ~7 at 400k, lnLmax ~143), fit the GMM to AV's posterior samples +(`fit_extrinsic_proposal` reads any sampler's `_rvs`), consolidated, and seeded a **frozen** +GMM run: + +- the seed lands cleanly (all 3 groups), **0 resets**, and n_eff rises from the cold ~1 to + **~5-10** -- the seed mechanism is injecting structure. +- BUT the seeded GMM's INTEGRAL is wrong: `sqrt(2 lnLmax)` prints `nan` and Z comes out + ~1e-4 (vs the cold GMM's valid ~1e43 and AV's lnLmax~143). High n_eff in the WRONG region + is worse than honest low n_eff: the frozen proposal is importance-sampling a region that is + consistent-but-displaced from the true posterior. + +Two suspects, not yet isolated (needs a focused audit, no more blind GPU time): +1. **coordinate convention** -- AV vs GMM may store extrinsic samples in `_rvs` under + different conventions (e.g. angle vs cosine for inclination/declination; the sampler adds + `inclination`/`declination` on `[-1,1]` = cosine when `--*-cosine-sampler` is set, but it + is not obvious AV's `_rvs` uses the same). A mismatch would place the fitted GMM in the + wrong frame. Same-sampler GMM->GMM has no such mismatch and round-trips cleanly. +2. **`cov_inflate` out of bounds** -- inflating the seed covariance (x2) can push a sampled + `distance` outside `[1,1000]` (or other hard edges) where the likelihood returns NaN, + contaminating lnLmax. Worth testing `cov_inflate=1` and clipping proposed samples. + +Net: the handoff machinery, the freeze, and the AV-source convergence all work; the +cross-sampler numeric correctness is one debugging session away (audit `_rvs` conventions + +inflation/bounds). The same-sampler GMM->GMM path is already numerically clean -- it just +needs a sampler that converges as a source, i.e. seedable AV (below). + +## Cross-sampler AV->GMM: numerics RESOLVED; benefit gated by GMM convergence + +Debugging the wrong-integral above (per user's steer) found and fixed FOUR real issues in the +save/seed path; the cross-sampler seed is now numerically correct: + +1. **tempered weights (save side).** The GPU/AV sampler (mcsamplerGPU) stores + `_rvs['log_weights'] = tempering_exp*lnL + ln(prior) - ln(s_prior)` -- the adapt-weight- + exponent (e.g. 0.1) baked in. Fitting the GMM to those flattened weights displaces the + proposal. Fix: build the weight from the raw, UNTEMPERED components + (`log_integrand + log_joint_prior - log_joint_s_prior`) and prefer them over `log_weights`. + (GMM's own `_rvs` has no tempering -> GMM->GMM was already fine.) This alone took the + seeded n_eff from ~5 to ~26. +2. **cov_inflate.** Inflating a FROZEN seed only widens it out of bounds; default is now 1.0 + (freeze handles robustness; inflation was for the adapt=True path). +3. **starved fit -> NaN component.** A low-ESS source over-parameterized (n_comp=4 vs few + effective samples) collapses a mixture component to a singular/NaN covariance, and one NaN + component poisons the whole seeded proposal. Fix: cap n_comp by the weight ESS + (`k <= ESS/(d+2)`) and drop any non-finite component (renormalize; skip the group if none + survive). +4. **distance sampled against a hard bound.** The real source of the persistent `nan` lnLmax: + with distance SAMPLED on `[1,1000]`, a seeded distance Gaussian spills past the bound -> + NaN likelihood. Distance marginalization (`--distance-marginalization` + a lookup table + from `util_InitMargTable`) removes distance from the extrinsic sampler entirely; with it on, + the seeded run's lnLmax is finite and the integral is valid. Distmarg is OPTIONAL with the + fused kernel, not required -- the fused kernel has both a non-distmarg kernel + (`Q_fused_calmarg_cupy`) and a distmarg kernel (`Q_fused_calmarg_distmarg_cupy`), and the + ILE binary wires whichever applies. In the pipeline it is `--internal-marginalize-distance` + (which composes cleanly with `--calmarg-fused-kernel`); in the demo it is the `PP_DMARG=1` + toggle. RECOMMENDED with `--extrinsic-handoff` precisely because it removes the distance + dimension + its hard bound from the seeded GMM proposal. + +Measured, distmarg on, single CI point (SNR~17.5), all fixes in: +- AV source converges to n_eff~4.7 (lnLmax~152), writes a clean 2-group (sky, phase/pol) proposal. +- seeded GMM: 0 resets, FINITE lnLmax, VALID integral -- but n_eff ~1.0, ~the same as the + cold GMM (~1.0-1.3). The seed neither helps nor hurts. + +**Conclusion.** The handoff (save -> consolidate -> seed) is now numerically correct and safe +end-to-end on real GPU data. But it does not ACCELERATE on this point because the seedable +sampler (GMM) does not converge here (n_eff~1 cold AND seeded), and the AV source (n_eff~5) is +too under-converged to provide a strongly-informative seed. GMM is seedable but weak; AV +converges but is not seedable. This is now hard evidence that the payoff requires a +**seedable / partial-reset AV (task #30, #25)** -- or a converged source (lower SNR / much +larger sample budget / better GPU) so the GMM seed has real information to carry. The numeric +substrate is done; the win is one of those two regimes away. + +## Why GMM first, and the AV limitation (task #30) + +The adaptive Voronoi sampler (AV, `mcsampler`) is the default extrinsic sampler and is more +efficient, but it **completely resets** between `integrate()` calls -- there is no seed path, +and re-seeding is dangerous because AV can only *contract* its boundaries, never expand or +shift them. So a naive AV warm-start could lock the sampler onto a stale region. The GMM +(and portfolio) samplers reuse sampling models cleanly and are trivially seedable, so they +are the right vehicle for the first working handoff. + +Future work (task #30): a **seedable / partial-reset AV** -- reset only some parameters, or +seed a proposal that AV is allowed to *expand* from, so the more-efficient sampler can also +benefit from the handoff. The breadcrumb `kind` field already leaves room for a non-GMM +model behind the same `save`/`load`/seed interface. diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/adaptive.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/adaptive.py new file mode 100644 index 000000000..eb502050f --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/adaptive.py @@ -0,0 +1,285 @@ +""" +Adaptive calibration sampling (Phase 1). + +Motivation +---------- +In-loop calibration marginalization draws cal realizations from the PRIOR. As SNR +grows the calibration parameters become measurable, so the cal posterior pulls away +from the prior and almost all prior draws land in low-likelihood cal regions -- the +effective number of cal samples collapses. This module learns a unimodal Gaussian +PROPOSAL in cal spline-node space and uses importance weighting (w_c = prior/proposal) +so the marginalized result stays unbiased while the sampling efficiency recovers. + +Tempering +--------- +The per-realization responsibilities (log integral contributions) have a very large +dynamic range at high SNR -- a naive Gaussian fit would be dominated by a single +sample. We fit with TEMPERED weights softmax(beta * log_resp), starting at small beta +(broad, many samples contribute) and ramping beta -> 1 as the proposal narrows onto the +cal posterior. The importance weights used for the *marginalization itself* are always +the full (untempered) w_c = prior/proposal; tempering only shapes the proposal fit. + +This module is backend-agnostic numpy and has no GPU/lal dependency for the learning +machinery itself (it consumes an `evaluate` callback that runs the actual likelihood). +The cal-factor construction reuses the spline convention in generate_realizations. +""" +from __future__ import division + +import numpy as np +import scipy.interpolate +from scipy.special import logsumexp + +from RIFT.calmarg import generate_realizations as _gr + + +# --------------------------------------------------------------------------- +# Prior / node bookkeeping +# --------------------------------------------------------------------------- +def envelope_node_prior(fname, fmin, fmax, n_nodes): + """Per-node Gaussian prior (mean, sigma) for amplitude then phase nodes, and the + log10 spline node frequencies. Node vector layout: [amp_0..amp_{N-1}, ph_0..ph_{N-1}].""" + log_f = np.linspace(np.log10(fmin), np.log10(fmax), n_nodes) + dat_amp, dat_phase = _gr.retrieve_envelope_from_file(fname, frequency_array=10 ** log_f) + mean = np.concatenate([dat_amp[:, 1], dat_phase[:, 1]]) + sigma = np.concatenate([dat_amp[:, 2], dat_phase[:, 2]]) + sigma = np.where(sigma > 0, sigma, 1.0) # guard degenerate (delta) priors + return mean, sigma, log_f + + +def log_prior(nodes, prior_mean, prior_sigma): + """Independent-Gaussian prior log-pdf for each row of `nodes` (n_real, dim).""" + z = (nodes - prior_mean) / prior_sigma + return np.sum(-0.5 * z * z - np.log(prior_sigma * np.sqrt(2 * np.pi)), axis=1) + + +def _mvn_logpdf(nodes, mean, cov): + dim = mean.shape[0] + d = nodes - mean + L = np.linalg.cholesky(cov) + sol = np.linalg.solve(L, d.T).T # (n_real, dim) + quad = np.sum(sol * sol, axis=1) + logdet = 2.0 * np.sum(np.log(np.diag(L))) + return -0.5 * (quad + logdet + dim * np.log(2 * np.pi)) + + +# --------------------------------------------------------------------------- +# Cal-factor construction (spline; matches generate_realizations convention) +# --------------------------------------------------------------------------- +def nodes_to_cal_factors(amp_nodes, phase_nodes, log_f_nodes, T_segment, dT, fmin, fmax): + """Build two-sided complex calibration factors (npts_seg, n_real) from per-node + amplitude/phase values, on the lalsimutils FFT frequency packing. + + amp_nodes, phase_nodes : (n_real, n_nodes) + """ + n_real = amp_nodes.shape[0] + deltaF_seg = 1. / T_segment + npts_seg = int(T_segment / dT) + freq = deltaF_seg * np.array([npts_seg / 2 - k if k <= npts_seg / 2 else -k + npts_seg / 2 + for k in np.arange(npts_seg)]) + mask_in = (np.abs(freq) >= fmin) & (np.abs(freq) <= fmax) + mask_plus = mask_in & (freq > 0) + mask_minus = mask_in & (freq < 0) + lf_pos = np.log10(freq[mask_plus]) + lf_neg = np.log10(-freq[mask_minus]) + + out = np.ones((npts_seg, n_real), dtype=complex) + for i in range(n_real): + cs_a = scipy.interpolate.CubicSpline(log_f_nodes, amp_nodes[i]) + cs_p = scipy.interpolate.CubicSpline(log_f_nodes, phase_nodes[i]) + out[mask_plus, i] = cs_a(lf_pos) * np.exp(1j * cs_p(lf_pos)) + out[mask_minus, i] = cs_a(lf_neg) * np.exp(-1j * cs_p(lf_neg)) + return out + + +# --------------------------------------------------------------------------- +# Tempered proposal fit + diagnostics +# --------------------------------------------------------------------------- +def fit_proposal(nodes, log_resp, beta, cov_floor=1e-8, cov_inflate=1.0, + prior_sigma=None, shrink=None): + """Tempered weighted-Gaussian fit. Weights = softmax(beta * log_resp). + + beta in (0,1]: small -> broad (many samples), 1 -> full responsibility weighting. + + prior_sigma : if given (length-dim 1-sigma of the diagonal prior), SHRINK the fitted + covariance toward diag(prior_sigma**2). This is essential when the fit is starved + -- a weighted sample covariance from ~neff effective points cannot constrain the + dim*(dim+1)/2 entries of a dim-dimensional covariance (cal node space is ~60-D), + so the UNINFORMED directions otherwise collapse to ~0 variance. A near-zero + proposal variance is a near-delta: seeded draws are pinned and the importance + weights log(prior/proposal) blow up, producing the pathological seeded likelihoods + we saw. Shrinking keeps uninformed directions at ~prior width (log_w ~ 0 there). + shrink : explicit shrinkage weight rho in [0,1] toward the prior; default auto = + (dim+1)/(dim+1+neff), i.e. ~1 (all prior) when starved, ->0 (all data) when + neff >> dim. + + Returns (mean, cov).""" + lw = beta * log_resp + lw = lw - logsumexp(lw) + w = np.exp(lw) + mean = w @ nodes + d = nodes - mean + cov = (w[:, None] * d).T @ d + dim = mean.shape[0] + if prior_sigma is not None: + prior_sigma = np.asarray(prior_sigma, dtype=float) + neff = neff_from_logweights(beta * log_resp) + rho = shrink if shrink is not None else (dim + 1.0) / (dim + 1.0 + neff) + rho = float(min(max(rho, 0.0), 1.0)) + cov = (1.0 - rho) * cov_inflate * cov + rho * np.diag(prior_sigma ** 2) + else: + cov = cov_inflate * cov + cov = cov + cov_floor * np.eye(dim) + return mean, cov + + +def neff_from_logweights(log_w): + """Kish effective sample size from log-weights: (sum w)^2 / sum w^2.""" + return float(np.exp(2 * logsumexp(log_w) - logsumexp(2 * log_w))) + + +def cal_mc_error_from_components(comp, cal_log_weights=None, sample_log_weights=None): + """Calibration Monte-Carlo error budget for the cal-marginalized evidence. + + The in-loop marginalization estimates Z = E_c[ w_c Z_c ] over n_cal iid cal + draws, where Z_c = int dtheta p(theta) L(theta, c). The extrinsic sampler's + reported variance CANNOT see the spread over c (the draw set is held fixed for + the whole job), so this term must be estimated separately and added in + quadrature to the extrinsic sampling error. + + comp : (n_samples, n_cal) RAW per-realization time-integrated lnL at a batch of + extrinsic samples (``return_cal_components=True`` output). + cal_log_weights : (n_cal,) importance log-weights log(prior/proposal); + None = prior draws (uniform). + sample_log_weights : (n_samples,) posterior log-weights of the extrinsic batch. + For a batch drawn from the extrinsic PRIOR pass None: the marginal lnL of + each sample (logsumexp_c of comp+cal_log_weights) is then the correct + importance weight. + + Returns (sigma_lnZ_cal, neff_cal, a_c): + a_c : (n_cal,) normalized posterior contribution of realization c, + a_c = w_c Z_c / (n_cal Z); sums to 1. + sigma_lnZ_cal : delta-method standard error of lnZ from the cal MC average, + Var(lnZ) ~= n_cal * Var_c(a_c). (Lognormal cross-check: this + reproduces (exp(sigma_lnL^2)-1)/n_cal.) + neff_cal : Kish size 1 / sum_c a_c^2. When neff_cal is O(1) the + marginalization is dominated by a single draw and the error + estimate itself is a LOWER BOUND -- treat the point as unreliable. + """ + comp = np.atleast_2d(np.asarray(comp, dtype=float)) + n_samples, n_cal = comp.shape + logw = np.zeros(n_cal) if cal_log_weights is None else np.asarray(cal_log_weights, dtype=float) + lc = comp + logw[None, :] # log( w_c L_jc ) + lnL_marg = logsumexp(lc, axis=1) # per-sample log sum_c w_c L_jc (norm cancels) + log_r = lc - lnL_marg[:, None] # responsibilities r_jc, sum_c r_jc = 1 + if sample_log_weights is None: + slw = lnL_marg # prior-drawn batch -> weight by marginal L + else: + slw = np.asarray(sample_log_weights, dtype=float) + slw = slw - logsumexp(slw) # sum_j W_j = 1 + log_a = logsumexp(slw[:, None] + log_r, axis=0) # a_c = sum_j W_j r_jc + a_c = np.exp(log_a - logsumexp(log_a)) # exact renormalization + var_lnZ = n_cal * np.var(a_c, ddof=1) if n_cal > 1 else 0.0 + neff_cal = 1.0 / np.sum(a_c ** 2) + return float(np.sqrt(max(var_lnZ, 0.0))), float(neff_cal), a_c + + +# --------------------------------------------------------------------------- +# Adaptive loop +# --------------------------------------------------------------------------- +def adaptive_cal(evaluate, prior_mean, prior_sigma, n_nodes_amp, + n_real=200, n_iter=4, betas=None, rng=None, return_history=False): + """Run the adaptive cal-sampling loop. + + evaluate(nodes) -> log_L : callback returning, for each realization (row of + `nodes`), the extrinsic-marginalized log-likelihood log integral_theta + L(theta, cal(nodes_c)) -- NO prior, NO importance weight (the loop folds those + in). In practice `evaluate` builds the cal factors (nodes_to_cal_factors) and + runs the ILE integral per realization. + + The per-realization posterior responsibility (used to fit the next proposal and to + measure efficiency) is log_w + log_L = log( prior(c) * integral L / proposal(c) ), + i.e. posterior/proposal; neff of these -> n_real exactly when the proposal matches + the cal posterior. The final `log_w` are the importance weights for the + marginalization itself ( Z_cal = sum_c exp(log_w_c) integral L_c ). + + Returns dict with the final realizations' `nodes`, `log_w` (prior/proposal, for the + marginalization), `proposal` (mean,cov), and per-iteration `neff` history. + """ + rng = rng or np.random.default_rng() + dim = prior_mean.shape[0] + if betas is None: + # ramp tempering 0.3 -> 1.0 + betas = np.linspace(0.3, 1.0, n_iter) + mean = prior_mean.copy() + cov = np.diag(prior_sigma ** 2) + + history = [] + nodes = log_w = None + for it in range(n_iter): + nodes = rng.multivariate_normal(mean, cov, size=n_real) # (n_real, dim) + log_q = _mvn_logpdf(nodes, mean, cov) + log_p = log_prior(nodes, prior_mean, prior_sigma) + log_w = log_p - log_q # importance weights + log_L = np.asarray(evaluate(nodes)) # extrinsic-marg log-like + log_resp = log_w + log_L # posterior/proposal + # next proposal from tempered posterior responsibilities; inflate the covariance + # early (while tempering is on) to keep exploring, relax as beta -> 1. + beta = float(betas[min(it, len(betas) - 1)]) + mean, cov = fit_proposal(nodes, log_resp, beta, cov_inflate=1.0 + (1.0 - beta), + prior_sigma=prior_sigma) + neff_resp = neff_from_logweights(log_resp) + neff_w = neff_from_logweights(log_w) + history.append(dict(iter=it, beta=beta, neff_resp=neff_resp, neff_w=neff_w)) + + out = dict(nodes=nodes, log_w=log_w, proposal_mean=mean, proposal_cov=cov, + history=history) + if return_history: + out['history'] = history + return out + + +# --------------------------------------------------------------------------- +# Self-contained convergence demo (mock likelihood): no GPU/lal needed +# --------------------------------------------------------------------------- +if __name__ == "__main__": + # A "true" calibration sits ~3 sigma off the prior mean in node space, with a + # narrow likelihood (high SNR -> measurable cal). Prior-only sampling would have + # tiny neff; the adaptive loop should lock onto it and neff should climb. + rng = np.random.default_rng(1234) + dim = 8 + prior_mean = np.zeros(dim) + prior_sigma = np.ones(dim) + # measurable cal ~2 sigma off the prior, narrow likelihood (high SNR) + true_node = prior_mean + 2.0 * prior_sigma * rng.standard_normal(dim) / np.sqrt(dim) + like_sigma = 0.4 + + def evaluate(nodes): + # extrinsic-marginalized log-like proxy (no prior, no weights -- the loop adds them) + z = (nodes - true_node) / like_sigma + return -0.5 * np.sum(z * z, axis=1) + + # analytic cal posterior (Gaussian prior x Gaussian like): mean pulled from `true` + # toward the prior mean; this is the target the proposal should converge to. + w_like = 1.0 / like_sigma ** 2 + w_prior = 1.0 / prior_sigma ** 2 + post_mean = (true_node * w_like + prior_mean * w_prior) / (w_like + w_prior) + + # prior-only baseline: neff of the posterior responsibilities prior*L/prior = L + base = rng.multivariate_normal(prior_mean, np.diag(prior_sigma ** 2), size=300) + base_neff = neff_from_logweights(evaluate(base)) + err0 = float(np.max(np.abs(post_mean - prior_mean))) + print("prior-only neff_resp = %.1f / 300 (posterior is %.2f sigma off the prior mean)" + % (base_neff, err0)) + + res = adaptive_cal(evaluate, prior_mean, prior_sigma, n_nodes_amp=dim // 2, + n_real=300, n_iter=6, rng=rng) + for h in res['history']: + print("iter %d beta=%.2f neff_resp=%6.1f neff_w=%6.1f" % ( + h['iter'], h['beta'], h['neff_resp'], h['neff_w'])) + err = float(np.max(np.abs(res['proposal_mean'] - post_mean))) + print("proposal mean vs cal posterior: max|delta| = %.3f sigma" % err) + assert res['history'][-1]['neff_resp'] > 10 * base_neff, \ + "adaptive did not improve effective cal sample size" + assert err < 0.3, "proposal did not converge onto the cal posterior" + print("\nPASS: tempered adaptive cal sampling converges onto the cal posterior " + "and recovers effective samples.") diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/backtest_calmarg.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/backtest_calmarg.py new file mode 100644 index 000000000..f0acb42b8 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/backtest_calmarg.py @@ -0,0 +1,536 @@ +""" +backtest_calmarg.py -- backtest harness for in-loop calibration marginalization + +PURPOSE +------- +Compare different implementations of the calibration-marginalized factored +likelihood against each other and against a brute-force reference, on controlled +inputs that exercise the per-realization block structure +(rholmsArrayDict[det] holding n_cal contiguous length-N_window blocks, selected by +ifirst -> ifirst + c*N_window). + +This is the rig Option C (a fused CUDA kernel) is developed against: register the +new implementation in METHODS and the harness reports lnL agreement (vs the +brute-force reference and vs Option B) and timing, on both CPU and GPU backends. + +It is deliberately self-contained (synthetic inputs, no frames/PSDs/cache needed), +so it runs anywhere RIFT + lal import. See run_physics_backtest() below for the +heavier real-data comparison vs bilby's calibration_reweighting.py. + +METHODS (the registry being backtested) +---------------------------------------- + reference : brute force -- run the unchanged n_cal==1 likelihood on each + realization block separately, combine logsumexp_c(lnL_c) - log(n_cal). + This is the ground truth the others must reproduce. + in_loop_B : DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop(..., n_cal=n_cal) + -- Option B (cal loop reusing the existing Q kernel, streaming LSE). + in_loop_C : Option C fused kernel -- STUB, raises NotImplementedError. Wire the + new implementation here when it exists; the harness will validate it. + +USAGE +----- + python -m RIFT.calmarg.backtest_calmarg --backend cpu --n-cal 20 + python -m RIFT.calmarg.backtest_calmarg --backend gpu --n-cal 100 --npts-extrinsic 4096 --repeat 5 + python -m RIFT.calmarg.backtest_calmarg --backend gpu --methods reference,in_loop_B,in_loop_C +""" +from __future__ import print_function + +import argparse +import time + +import numpy as np +import lal +from scipy.special import logsumexp + +import RIFT.likelihood.factored_likelihood as fl + + +# --------------------------------------------------------------------------- +# Synthetic case construction +# --------------------------------------------------------------------------- +def make_synthetic_case(n_cal=20, npts_extrinsic=64, N_window=256, npts=16, + deltaT=1.0/4096, dets=("H1", "L1"), seed=1234, + psd_UV=False): + """Build a controlled set of likelihood inputs with embedded cal-block + structure. The n_cal realization blocks are independent random rholm draws + (their physical relationship is irrelevant for backtesting that the *reduction* + over blocks is computed correctly -- the loglikelihood callback is applied + identically across methods, so method agreement holds regardless). + + Multiple detectors exercise the kernel's detector loop and the function's + per-detector stacking; each detector gets its own random rholms/U/V, and the + likelihood derives a distinct per-detector ifirst from the (real) detector + location, so the stacked-ifirst path is genuinely tested. + + N_window must exceed the sky time-delay spread (+-0.021 s) plus npts so the + per-sample window stays inside each block. + + Returns a dict ('case') of plain numpy arrays / scalars (rholms/U/V are dicts + keyed by detector); convert to a backend in the method functions. + """ + rng = np.random.default_rng(seed) + n_lms = 2 + npts_full = N_window * n_cal + dets = tuple(dets) + + case = dict( + dets=dets, n_cal=n_cal, n_lms=n_lms, N_window=N_window, npts=npts, + deltaT=deltaT, npts_extrinsic=npts_extrinsic, + lookupNK=np.array([[2, 2], [2, -2]], dtype=int), + tref=1000000000.0, + ) + case["rholms"] = {} + case["U"] = {} + case["V"] = {} + for det in dets: + case["rholms"][det] = (rng.standard_normal((n_lms, npts_full)) + + 1j*rng.standard_normal((n_lms, npts_full))) + if psd_UV: + # positive-definite U, V=0 -> rho_sq>0, required by the distmarg + # transforms (asinh(rho_sq/bref), x0=kappa/rho_sq); mirrors physical + case["U"][det] = np.eye(n_lms, dtype=complex) + case["V"][det] = np.zeros((n_lms, n_lms), dtype=complex) + else: + U = rng.standard_normal((n_lms, n_lms)) + 1j*rng.standard_normal((n_lms, n_lms)) + V = rng.standard_normal((n_lms, n_lms)) + 1j*rng.standard_normal((n_lms, n_lms)) + case["U"][det] = U + U.conj().T + case["V"][det] = V + V.conj().T + # epoch placed so the integration window sits near the middle of each block + case["epoch"] = case["tref"] - 0.03 + case["tvals"] = np.linspace(-(npts//2)*deltaT, (npts//2)*deltaT, npts) + + # extrinsic parameter arrays + case["phi"] = rng.uniform(0, 2*np.pi, npts_extrinsic) + case["theta"] = rng.uniform(0.2, np.pi-0.2, npts_extrinsic) + case["psi"] = rng.uniform(0, np.pi, npts_extrinsic) + case["incl"] = rng.uniform(0.2, np.pi-0.2, npts_extrinsic) + case["phiref"] = rng.uniform(0, 2*np.pi, npts_extrinsic) + case["dist"] = np.full(npts_extrinsic, 500.0) * (lal.PC_SI*1e6) # 500 Mpc + return case + + +class _PVec(object): + """Minimal stand-in for the vectorized ChooseWaveformParams object that the + likelihood reads (phi, theta, psi, incl, phiref, dist arrays; tref, deltaT + scalars).""" + pass + + +def _build_P(case, xpy): + P = _PVec() + for name in ("phi", "theta", "psi", "incl", "phiref", "dist"): + setattr(P, name, xpy.asarray(case[name])) + P.tref = case["tref"] + P.deltaT = case["deltaT"] + return P + + +def _backend(name): + if name == "cpu": + return np + if name == "gpu": + import cupy as cp + return cp + raise ValueError("backend must be 'cpu' or 'gpu', got %r" % name) + + +def _to_host(x): + try: + import cupy as cp + if isinstance(x, cp.ndarray): + return cp.asnumpy(x) + except ImportError: + pass + return np.asarray(x) + + +def _dicts(case, xpy, rholms): + """Build the per-detector dicts the likelihood expects from a rholms map.""" + dets = case["dets"] + lookupNKDict = {d: case["lookupNK"] for d in dets} + rholmsArrayDict = {d: xpy.asarray(rholms[d]) for d in dets} + ctU = {d: xpy.asarray(case["U"][d]) for d in dets} + ctV = {d: xpy.asarray(case["V"][d]) for d in dets} + epochDict = {d: case["epoch"] for d in dets} + return lookupNKDict, rholmsArrayDict, ctU, ctV, epochDict + + +def _block_rholms(case, c): + """Per-detector rholms restricted to realization block c.""" + N = case["N_window"] + return {d: case["rholms"][d][:, c*N:(c+1)*N] for d in case["dets"]} + + +# --------------------------------------------------------------------------- +# Distance-marginalization table + loglikelihood (mirror of the ILE driver, so the +# fused distmarg kernel can be validated against reference/Option B using the SAME +# table and transforms) +# --------------------------------------------------------------------------- +def _bilinear(s0, ds, t0, dt, fgrid, xpy): + """Mirror of EvenBivariateLinearInterpolator in the ILE driver.""" + dx_inv, dy_inv = 1.0/ds, 1.0/dt + + def call(x, y): + i_mid = dx_inv * (x - s0) + j_mid = dy_inv * (y - t0) + i_lo = xpy.floor(i_mid).astype(int); i_hi = xpy.ceil(i_mid).astype(int) + j_lo = xpy.floor(j_mid).astype(int); j_hi = xpy.ceil(j_mid).astype(int) + p = i_mid - i_lo; q = j_mid - j_lo + p_ = 1 - p; q_ = 1 - q + f = p_*q_ * fgrid[i_lo, j_lo] + f += p*q_ * fgrid[i_hi, j_lo] + f += p_*q * fgrid[i_lo, j_hi] + f += p*q * fgrid[i_hi, j_hi] + return f + return call + + +def make_distmarg_table(xpy, ns=64, nt=48, xmin=-1.0e4, xmax=1.0e4, + sqrt_bmax=1.0, bref=1.0, tmax=10.0, seed=7): + """Build a synthetic-but-self-consistent distance-marginalization table. + + s_array spans x0_to_s(xmin)..x0_to_s(xmax), so any x0 in (xmin,xmax) maps to an + in-bounds s; wide (xmin,xmax) keeps realized x0=kappa/rho_sq in range. lnI_array + is an arbitrary smooth surface -- physical values are irrelevant for backtesting + that the kernel reproduces the same transform the Python closure applies. + """ + def x0_to_s(x0): + return (np.arcsinh(sqrt_bmax*(x0 - xmin)) + - np.arcsinh(sqrt_bmax*(xmax - x0))) + smin = float(x0_to_s(xmin)) + smax = float(x0_to_s(xmax)) + s_array = np.linspace(smin, smax, ns) + t_array = np.linspace(0.0, tmax, nt) + SS, TT = np.meshgrid(s_array, t_array, indexing='ij') + lnI_array = -0.3*SS**2 + np.cos(TT) - 0.05*TT # smooth, arbitrary + + return dict( + lnI_array=xpy.asarray(lnI_array), + s0=float(s_array[0]), ds=float(s_array[1]-s_array[0]), + smin=float(s_array[0]), smax=float(s_array[-1]), + t0=float(t_array[0]), dt=float(t_array[1]-t_array[0]), + tmax=float(t_array[-1]), + xmin=float(xmin), xmax=float(xmax), + sqrt_bmax=float(sqrt_bmax), bref=float(bref), + ) + + +def load_real_distmarg_table(npz_path, xpy, dmin=1.0, dmax=1000.0): + """Load a real util_InitMargTable .npz into the same params dict the kernel + + mirror closure consume. Lets us backtest against the production table's actual + s/t ranges (e.g. t_array[0] may be > 0) deterministically.""" + import RIFT.likelihood.factored_likelihood as _fl + d = np.load(npz_path) + s_array = np.asarray(d["s_array"]); t_array = np.asarray(d["t_array"]) + bmax = float(np.asarray(d["bmax"])); bref = float(np.asarray(d["bref"])) + return dict( + lnI_array=xpy.asarray(d["lnI_array"]), + s0=float(s_array[0]), ds=float(s_array[1]-s_array[0]), + smin=float(s_array[0]), smax=float(s_array[-1]), + t0=float(t_array[0]), dt=float(t_array[1]-t_array[0]), + tmax=float(t_array[-1]), + xmin=float(_fl.distMpcRef/dmax), xmax=float(_fl.distMpcRef/dmin), + sqrt_bmax=float(np.sqrt(bmax)), bref=bref, + ) + + +def make_distmarg_loglikelihood(params, xpy): + """Python distmarg loglikelihood closure (mirror of the ILE driver), consuming + the same table the fused kernel uses.""" + xmin, xmax = params["xmin"], params["xmax"] + sqrt_bmax, bref = params["sqrt_bmax"], params["bref"] + smin, smax, tmax = params["smin"], params["smax"], params["tmax"] + intp = _bilinear(params["s0"], params["ds"], params["t0"], params["dt"], + params["lnI_array"], xpy) + + def loglikelihood(kappa_sq, rho_sq): + x0 = kappa_sq / rho_sq + s = (xpy.arcsinh(sqrt_bmax*(x0 - xmin)) + - xpy.arcsinh(sqrt_bmax*(xmax - x0))) + t = xpy.arcsinh(rho_sq / bref) + lnI = xpy.full_like(x0, -xpy.inf) + in_bounds = (s > smin) & (s < smax) & (t < tmax) + lnI[in_bounds] = intp(s[in_bounds], t[in_bounds]) + x0c = xpy.clip(x0, xmin, xmax) + return rho_sq * x0c * (x0 - 0.5*x0c) + lnI + return loglikelihood + + +# --------------------------------------------------------------------------- +# Method implementations (the registry being backtested) +# --------------------------------------------------------------------------- +def method_reference(case, xpy, phase_marginalization=False, loglikelihood=None): + """Brute force: per-block n_cal==1 evaluation, combined by hand.""" + if loglikelihood is None: + loglikelihood = fl._factored_lnL_helper + P = _build_P(case, xpy) + tvals = xpy.asarray(case["tvals"]) + n_cal = case["n_cal"] + lnL_blocks = np.zeros((n_cal, case["npts_extrinsic"])) + for c in range(n_cal): + lookupNKDict, rholmsArrayDict, ctU, ctV, epochDict = _dicts( + case, xpy, _block_rholms(case, c)) + out = fl.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop( + tvals, P, lookupNKDict, rholmsArrayDict, ctU, ctV, + epochDict, Lmax=2, xpy=xpy, n_cal=1, + loglikelihood=loglikelihood, phase_marginalization=phase_marginalization) + lnL_blocks[c] = _to_host(out) + lw = case.get("cal_log_weights") + if lw is None: + return logsumexp(lnL_blocks, axis=0) - np.log(n_cal) + # unbiased importance estimate: (1/n_cal) sum_c w_c L_c -> normalize by log(n_cal) + lw = np.asarray(lw, dtype=float) + return logsumexp(lnL_blocks + lw[:, None], axis=0) - np.log(n_cal) + + +def method_in_loop_B(case, xpy, phase_marginalization=False, loglikelihood=None): + """Option B: single call with n_cal>1 (cal_method='loop').""" + if loglikelihood is None: + loglikelihood = fl._factored_lnL_helper + P = _build_P(case, xpy) + lookupNKDict, rholmsArrayDict, ctU, ctV, epochDict = _dicts( + case, xpy, case["rholms"]) + out = fl.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop( + xpy.asarray(case["tvals"]), P, lookupNKDict, rholmsArrayDict, ctU, ctV, + epochDict, Lmax=2, xpy=xpy, n_cal=case["n_cal"], cal_method='loop', + cal_log_weights=case.get("cal_log_weights"), + loglikelihood=loglikelihood, phase_marginalization=phase_marginalization) + return _to_host(out) + + +def method_in_loop_C(case, xpy, phase_marginalization=False, loglikelihood=None): + """Option C: fused CUDA kernel (Q + default helper + cal log-sum-exp on-board). + + GPU-only and (for now) default helper / no phase marginalization; raises + NotImplementedError otherwise, so the harness SKIPs it on CPU. + """ + if loglikelihood is None: + loglikelihood = fl._factored_lnL_helper + P = _build_P(case, xpy) + lookupNKDict, rholmsArrayDict, ctU, ctV, epochDict = _dicts( + case, xpy, case["rholms"]) + out = fl.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop( + xpy.asarray(case["tvals"]), P, lookupNKDict, rholmsArrayDict, ctU, ctV, + epochDict, Lmax=2, xpy=xpy, n_cal=case["n_cal"], cal_method='fused', + cal_distmarg=case.get("cal_distmarg"), + cal_log_weights=case.get("cal_log_weights"), + loglikelihood=loglikelihood, phase_marginalization=phase_marginalization) + return _to_host(out) + + +METHODS = { + "reference": method_reference, + "in_loop_B": method_in_loop_B, + "in_loop_C": method_in_loop_C, +} + + +# --------------------------------------------------------------------------- +# Comparison driver +# --------------------------------------------------------------------------- +def _sync(xpy): + if xpy is not np: + xpy.cuda.Stream.null.synchronize() + + +def run_backtest(methods, backend="cpu", repeat=3, phase_marginalization=False, + loglikelihood_mode="default", real_table=None, + random_cal_weights=False, **case_kwargs): + """Evaluate each method, time it, and report agreement vs 'reference' (if run) + and vs 'in_loop_B'. + + loglikelihood_mode: + 'default' -- the distance-unmarginalized helper. + 'distmarg' -- the distance-marginalization loglikelihood (uses positive-definite + U so rho_sq>0, builds a self-consistent table; reference/Option B + use the Python closure, Option C uses the fused distmarg kernel). + """ + xpy = _backend(backend) + loglikelihood = None + if loglikelihood_mode == "distmarg": + case_kwargs["psd_UV"] = True + case = make_synthetic_case(**case_kwargs) + case["dist"] = np.full(case["npts_extrinsic"], fl.distMpcRef) * (lal.PC_SI*1e6) + params = (load_real_distmarg_table(real_table, xpy) if real_table + else make_distmarg_table(xpy)) + case["cal_distmarg"] = params # consumed by the fused distmarg kernel + loglikelihood = make_distmarg_loglikelihood(params, xpy) + else: + case = make_synthetic_case(**case_kwargs) + if random_cal_weights: + # non-uniform importance log-weights, to validate the weighted reduction + rng = np.random.default_rng(20240601) + case["cal_log_weights"] = rng.normal(0.0, 1.5, size=case["n_cal"]) + # distmarg's asinh/bilinear differ at ULP level between numpy and the kernel, so + # the fused-vs-loop agreement is float-level rather than bit-level. + tol = 1e-9 if loglikelihood_mode == "default" else 1e-6 + print("# calmarg backtest backend=%s dets=%s n_cal=%d npts_extrinsic=%d N_window=%d npts=%d phase_marg=%s loglike=%s" + % (backend, ",".join(case["dets"]), case["n_cal"], case["npts_extrinsic"], + case["N_window"], case["npts"], phase_marginalization, loglikelihood_mode)) + + results = {} + timings = {} + for name in methods: + fn = METHODS[name] + try: + out = fn(case, xpy, phase_marginalization=phase_marginalization, + loglikelihood=loglikelihood) # warm-up / compile + _sync(xpy) + best = float("inf") + for _ in range(repeat): + t0 = time.perf_counter() + out = fn(case, xpy, phase_marginalization=phase_marginalization, + loglikelihood=loglikelihood) + _sync(xpy) + best = min(best, time.perf_counter() - t0) + results[name] = np.asarray(out) + timings[name] = best + print(" %-12s ok best %8.2f ms" % (name, best*1e3)) + except NotImplementedError as e: + print(" %-12s SKIP (%s)" % (name, e)) + except Exception as e: + print(" %-12s FAIL %s: %s" % (name, type(e).__name__, e)) + + # agreement + baseline = "reference" if "reference" in results else ( + "in_loop_B" if "in_loop_B" in results else None) + if baseline: + print("# max |lnL - %s| (tol %.0e):" % (baseline, tol)) + ok = True + for name, vals in results.items(): + if name == baseline: + continue + err = float(np.max(np.abs(vals - results[baseline]))) + flag = "OK" if err < tol else "**DIFF**" + if err >= tol: + ok = False + print(" %-12s %.3e %s" % (name, err, flag)) + print("# RESULT:", "PASS" if ok else "MISMATCH") + return ok + return True + + +# --------------------------------------------------------------------------- +# Physics backtest vs bilby calibration_reweighting.py (scaffold -- needs data) +# --------------------------------------------------------------------------- +def run_physics_backtest(precompute_or_config=None, cal_envelope_dir=None, + bilby_data_dump=None, **kwargs): + """Compare in-loop calibration marginalization to the bilby postprocessor on a + REAL event. This needs frames/PSDs/cache (or a saved ILE precompute) plus the + bilby data_dump used by calibration_reweighting.py, so it does NOT run in the + self-contained harness above. + + Intended flow (TODO, to run on the stable host): + 1. Build data_dict / psd_dict (real or injected) the same way ILE does, OR + load a saved precompute. + 2. cal = RIFT.calmarg.generate_realizations.create_realizations(env, ...) for + each detector from cal_envelope_dir. + 3. PrecomputeLikelihoodTerms(..., calibration_realizations=cal) -> cal-extended + rholms; pack with PackLikelihoodDataStructuresAsArrays. + 4. Evaluate the in-loop calmarg likelihood over the SAME extrinsic samples the + bilby reweighter used (read its posterior + weights), compare per-sample + lnL and the integrated log-evidence shift. + 5. Compare to bilby calibration_likelihood from calibration_reweighting.py. + Expect agreement to first order in cal amplitude; the apply-to-data + (RIFT) vs apply-to-template (bilby) convention differs at second order -- + quantify and record that difference here. + """ + raise NotImplementedError( + "Physics backtest needs real data/precompute + a bilby data_dump; " + "see docstring for the intended flow. Run on the stable host post-update.") + + +def scan_timing(methods, backend="gpu", n_cal_list=(1, 10, 50, 100, 200), + repeat=5, loglikelihood_mode="default", real_table=None, **case_kwargs): + """Per-likelihood-evaluation wall-time vs n_cal, to quantify the cost of + calibration marginalization (and the brute-force reference) for planning. + + Reports best-of-`repeat` ms per call for each method at each n_cal. The + reference (brute force) does n_cal separate n_cal==1 evaluations, so its cost + scales ~linearly in n_cal; loop reuses the kernel per realization; fused does it + in one launch. Multiply by (n_iterations * blocks-per-iteration) to estimate the + full-integration cost.""" + xpy = _backend(backend) + print("# timing scan backend=%s dets=%s npts_extrinsic=%d loglike=%s" + % (backend, case_kwargs.get("dets", "H1,L1"), + case_kwargs.get("npts_extrinsic", 64), loglikelihood_mode)) + print("# %-6s " % "n_cal" + "".join("%14s" % m for m in methods) + " (ms/eval, best of %d)" % repeat) + for n_cal in n_cal_list: + ck = dict(case_kwargs); ck["n_cal"] = n_cal + loglikelihood = None + if loglikelihood_mode == "distmarg": + ck["psd_UV"] = True + case = make_synthetic_case(**ck) + case["dist"] = np.full(case["npts_extrinsic"], fl.distMpcRef) * (lal.PC_SI*1e6) + params = (load_real_distmarg_table(real_table, xpy) if real_table + else make_distmarg_table(xpy)) + case["cal_distmarg"] = params + loglikelihood = make_distmarg_loglikelihood(params, xpy) + else: + case = make_synthetic_case(**ck) + row = [] + for name in methods: + fn = METHODS[name] + try: + fn(case, xpy, loglikelihood=loglikelihood) # warm-up + _sync(xpy) + best = float("inf") + for _ in range(repeat): + t0 = time.perf_counter() + fn(case, xpy, loglikelihood=loglikelihood) + _sync(xpy) + best = min(best, time.perf_counter() - t0) + row.append("%14.3f" % (best * 1e3)) + except Exception as e: + row.append("%14s" % ("ERR:" + type(e).__name__)) + print(" %-6d" % n_cal + "".join(row)) + + +def _parse_args(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--backend", default="cpu", choices=["cpu", "gpu"]) + p.add_argument("--methods", default="reference,in_loop_B,in_loop_C", + help="comma-separated subset of: %s" % ",".join(METHODS)) + p.add_argument("--n-cal", type=int, default=20) + p.add_argument("--dets", default="H1,L1", help="comma-separated detector prefixes") + p.add_argument("--npts-extrinsic", type=int, default=64) + p.add_argument("--N-window", type=int, default=256) + p.add_argument("--npts", type=int, default=16) + p.add_argument("--repeat", type=int, default=3, help="timing repetitions (best-of)") + p.add_argument("--loglikelihood", default="default", choices=["default", "distmarg"], + help="default helper, or distance-marginalization loglikelihood") + p.add_argument("--real-table", default=None, + help="path to a real util_InitMargTable .npz (distmarg mode) to backtest against") + p.add_argument("--random-cal-weights", action="store_true", + help="inject non-uniform per-realization importance log-weights (validate the weighted reduction)") + p.add_argument("--phase-marginalization", action="store_true") + p.add_argument("--seed", type=int, default=1234) + p.add_argument("--scan-ncal", default=None, + help="comma-separated n_cal values: time each method per n_cal instead of validating") + return p.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + methods = [m.strip() for m in args.methods.split(",") if m.strip()] + unknown = [m for m in methods if m not in METHODS] + if unknown: + raise SystemExit("unknown methods: %s (known: %s)" + % (unknown, list(METHODS))) + dets = tuple(d.strip() for d in args.dets.split(",") if d.strip()) + if args.scan_ncal: + n_cal_list = [int(x) for x in args.scan_ncal.split(",") if x.strip()] + scan_timing(methods, backend=args.backend, n_cal_list=n_cal_list, + repeat=args.repeat, loglikelihood_mode=args.loglikelihood, + real_table=args.real_table, npts_extrinsic=args.npts_extrinsic, + N_window=args.N_window, npts=args.npts, seed=args.seed, dets=dets) + raise SystemExit(0) + ok = run_backtest( + methods, backend=args.backend, repeat=args.repeat, + phase_marginalization=args.phase_marginalization, + loglikelihood_mode=args.loglikelihood, real_table=args.real_table, + random_cal_weights=args.random_cal_weights, + n_cal=args.n_cal, npts_extrinsic=args.npts_extrinsic, + N_window=args.N_window, npts=args.npts, seed=args.seed, dets=dets) + raise SystemExit(0 if ok else 1) diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/breadcrumbs.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/breadcrumbs.py new file mode 100644 index 000000000..07f1223bb --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/breadcrumbs.py @@ -0,0 +1,136 @@ +""" +Breadcrumbs: a portable, integrator-agnostic save/load object for a LEARNED proposal +distribution (Option B in DESIGN_adaptive_driver.md). + +The point of this module is the *interface and schema*, not the model. Today it carries +a Gaussian over calibration spline-node parameters (mean/cov) plus the prior; the schema +reserves an `extrinsic` slot so the SAME object can later carry the extrinsic proposal +(the decade-old "save the extrinsic distribution" goal), and a `kind` field so a +normalizing flow can drop in behind the same load()/sample() interface. + +Stored as a single .npz (arrays + a JSON metadata sidecar string). Keep the schema +STABLE: add fields, do not repurpose them; bump SCHEMA_VERSION on incompatible changes. + +EXTRINSIC slot (schema v2): a portable learned proposal over the EXTRINSIC parameters, +the decade-old "save the extrinsic distribution to inform the next iteration" goal. It is +a list of parameter GROUPS (matching the ILE GMM sampler's gmm_dict structure -- e.g. +(right_ascension, declination), (distance, inclination), (phi_orb, psi)); each group holds +a Gaussian mixture (means/covariances/weights) over its parameters, plus the parameter +NAMES (so the indices reconstruct against the next run's params_ordered) and the sampling +bounds. kind='gmm' now; a normalizing flow drops in later behind the same interface. +""" +from __future__ import division + +import json +import numpy as np + +SCHEMA_VERSION = 2 + + +def save(path, cal=None, extrinsic=None, kind="gaussian", meta=None): + """Write a breadcrumb file. + + cal : dict or None -- the learned calibration proposal, with keys + proposal_mean (dim,), proposal_cov (dim,dim), prior_mean (dim,), prior_sigma (dim,), + node_log_f (n_nodes,), n_nodes_amp (int), dets (list[str]). + Node-vector layout per detector: [amp_0..amp_{N-1}, phase_0..phase_{N-1}], + concatenated over `dets` in order. + extrinsic : dict or None -- the learned EXTRINSIC proposal: + {'kind': 'gmm', + 'groups': [ {'params': [name,...], 'means': (K,d), 'covariances': (K,d,d), + 'weights': (K,), 'bounds': (d,2)}, ... ]}. + One group per gmm_dict block; the GMM is over the group's params in `params` order. + kind : top-level kind tag ('gaussian' for the cal Gaussian; extrinsic kind is its own). + meta : json-able dict (iteration, n_pilot_points, neff_cal, source composite, ...). + """ + d = dict(schema_version=np.int64(SCHEMA_VERSION), kind=str(kind), + has_cal=np.bool_(cal is not None), has_extrinsic=np.bool_(extrinsic is not None), + meta_json=json.dumps(meta or {})) + if cal is not None: + d.update( + cal_proposal_mean=np.asarray(cal["proposal_mean"], dtype=float), + cal_proposal_cov=np.asarray(cal["proposal_cov"], dtype=float), + cal_prior_mean=np.asarray(cal["prior_mean"], dtype=float), + cal_prior_sigma=np.asarray(cal["prior_sigma"], dtype=float), + cal_node_log_f=np.asarray(cal["node_log_f"], dtype=float), + cal_n_nodes_amp=np.int64(cal["n_nodes_amp"]), + cal_dets=np.array(list(cal["dets"]), dtype=object), + ) + if extrinsic is not None: + groups = extrinsic["groups"] + d["ext_kind"] = str(extrinsic.get("kind", "gmm")) + d["ext_n_groups"] = np.int64(len(groups)) + for i, g in enumerate(groups): + d["ext_g%d_params" % i] = np.array(list(g["params"]), dtype=object) + d["ext_g%d_means" % i] = np.asarray(g["means"], dtype=float) + d["ext_g%d_covs" % i] = np.asarray(g["covariances"], dtype=float) + d["ext_g%d_weights" % i] = np.asarray(g["weights"], dtype=float) + d["ext_g%d_bounds" % i] = np.asarray(g["bounds"], dtype=float) + np.savez(path, **d) + return path + + +def load(path): + """Read a breadcrumb file -> dict {schema_version, kind, cal, extrinsic, meta}.""" + z = np.load(path, allow_pickle=True) + ver = int(z["schema_version"]) + if ver > SCHEMA_VERSION: + raise ValueError("breadcrumb schema_version %d newer than supported %d" + % (ver, SCHEMA_VERSION)) + out = dict(schema_version=ver, kind=str(z["kind"]), + meta=json.loads(str(z["meta_json"])), cal=None, extrinsic=None) + if bool(z["has_cal"]): + out["cal"] = dict( + proposal_mean=z["cal_proposal_mean"], proposal_cov=z["cal_proposal_cov"], + prior_mean=z["cal_prior_mean"], prior_sigma=z["cal_prior_sigma"], + node_log_f=z["cal_node_log_f"], n_nodes_amp=int(z["cal_n_nodes_amp"]), + dets=[str(x) for x in z["cal_dets"]], + ) + if "has_extrinsic" in z and bool(z["has_extrinsic"]): + groups = [] + for i in range(int(z["ext_n_groups"])): + groups.append(dict( + params=[str(x) for x in z["ext_g%d_params" % i]], + means=z["ext_g%d_means" % i], covariances=z["ext_g%d_covs" % i], + weights=z["ext_g%d_weights" % i], bounds=z["ext_g%d_bounds" % i], + )) + out["extrinsic"] = dict(kind=str(z["ext_kind"]), groups=groups) + return out + + +if __name__ == "__main__": + # round-trip smoke test + dim = 6 + cal = dict(proposal_mean=np.arange(dim, dtype=float), + proposal_cov=np.eye(dim) * 0.1, + prior_mean=np.zeros(dim), prior_sigma=np.ones(dim), + node_log_f=np.linspace(1, 3, dim // 2), n_nodes_amp=dim // 2, + dets=["H1", "L1", "V1"]) + import tempfile, os + p = os.path.join(tempfile.mkdtemp(), "bc.npz") + save(p, cal=cal, meta=dict(iteration=2, neff_cal=87.3)) + g = load(p) + assert g["kind"] == "gaussian" and g["cal"]["dets"] == ["H1", "L1", "V1"] + assert np.allclose(g["cal"]["proposal_mean"], cal["proposal_mean"]) + assert g["meta"]["iteration"] == 2 + + # extrinsic (GMM) round-trip + ext = dict(kind="gmm", groups=[ + dict(params=["right_ascension", "declination"], + means=np.array([[1.0, 0.2], [4.0, -0.3]]), + covariances=np.array([np.eye(2) * 0.05, np.eye(2) * 0.1]), + weights=np.array([0.6, 0.4]), + bounds=np.array([[0.0, 2 * np.pi], [-np.pi / 2, np.pi / 2]])), + dict(params=["distance", "inclination"], + means=np.array([[500.0, 1.0]]), covariances=np.array([np.diag([1e4, 0.1])]), + weights=np.array([1.0]), bounds=np.array([[1.0, 1000.0], [0.0, np.pi]])), + ]) + p2 = os.path.join(tempfile.mkdtemp(), "bc2.npz") + save(p2, cal=cal, extrinsic=ext, meta=dict(iteration=3)) + g2 = load(p2) + assert g2["extrinsic"]["kind"] == "gmm" and len(g2["extrinsic"]["groups"]) == 2 + assert g2["extrinsic"]["groups"][0]["params"] == ["right_ascension", "declination"] + assert np.allclose(g2["extrinsic"]["groups"][0]["means"], ext["groups"][0]["means"]) + assert np.allclose(g2["extrinsic"]["groups"][1]["covariances"], ext["groups"][1]["covariances"]) + assert g2["cal"] is not None # cal + extrinsic coexist in one breadcrumb + print("PASS: breadcrumb save/load round-trips (cal Gaussian + extrinsic GMM, schema v%d)." % SCHEMA_VERSION) diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/extrinsic_handoff.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/extrinsic_handoff.py new file mode 100644 index 000000000..f4701aa76 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/extrinsic_handoff.py @@ -0,0 +1,232 @@ +""" +Extrinsic handoff: learn a portable proposal over the EXTRINSIC parameters from one +iteration's posterior samples, and seed the NEXT iteration's extrinsic sampler from it. + +This is the decade-old "save the extrinsic distribution to inform the next iteration" +goal, generalized from the calibration pilot's breadcrumb (RIFT.calmarg.breadcrumbs). +The extrinsic posterior barely changes from iteration to iteration (it is set by the data ++ best-fit template, not by the small intrinsic-grid moves), so carrying it forward lets +the sampler start near the answer instead of cold each time. + +GMM-first (this module): RIFT's ensemble sampler (mcsamplerEnsemble) is already seedable -- +its `gmm_dict` maps parameter GROUPS (tuples of indices into params_ordered) to a fitted +`gaussian_mixture_model.gmm`, and a non-None entry is used as the starting proposal. So +the handoff is: + fit_extrinsic_proposal(samples, log_weights, groups, bounds) # per group, RIFT gmm.fit + -> a portable breadcrumb 'extrinsic' dict (means/covs/weights/bounds + param names) + gmm_dict_from_breadcrumb(extrinsic, params_ordered) + -> reconstruct gmm objects, keyed by the dim-group indices, for the next sampler. + +We use RIFT's OWN gmm.fit (the same fitter the sampler uses in update_sampling_prior), so +the stored means/covariances are in exactly the model's internal (normalized) frame and +restore to a byte-identical model -- no coordinate guesswork. A normalizing flow can later +drop in behind the same fit/seed interface (breadcrumb kind != 'gmm'). + +The standard extrinsic groups (matching the ILE GMM gmm_dict) are + (right_ascension, declination), (distance, inclination), (phi_orb, psi). +""" +from __future__ import division + +import numpy as np + +# RIFT's ensemble-sampler GMM (the one gmm_dict expects). Imported lazily so this module +# is importable without the integrator stack (e.g. for breadcrumb round-trip tests). +def _gmm_module(): + import RIFT.integrators.gaussian_mixture_model as GMM + return GMM + + +STANDARD_GROUPS = [ + ["right_ascension", "declination"], + ["distance", "inclination"], + ["phi_orb", "psi"], +] + + +def fit_extrinsic_proposal(samples, log_weights, groups=None, bounds=None, + n_comp=4, max_iters=1000): + """Fit a per-group Gaussian mixture to extrinsic POSTERIOR samples. + + samples : dict {param_name: 1-D array (n,)} -- the extrinsic samples of one run. + log_weights : (n,) importance log-weights (log L + log prior - log sampling_prior), + the same weights the sampler's update_sampling_prior uses. None -> uniform. + groups : list of param-name lists (default STANDARD_GROUPS); only groups whose params + are ALL present in `samples` are fit. + bounds : dict {param_name: (lo, hi)} sampling bounds (required for the GMM frame). + n_comp : mixture components per group. + + Returns the breadcrumb 'extrinsic' dict: + {'kind': 'gmm', 'groups': [ {'params', 'means'(K,d), 'covariances'(K,d,d), + 'weights'(K,), 'bounds'(d,2)}, ... ]}. + """ + GMM = _gmm_module() + if groups is None: + groups = STANDARD_GROUPS + if bounds is None: + raise ValueError("fit_extrinsic_proposal needs per-parameter sampling bounds") + n = len(next(iter(samples.values()))) + lw = np.zeros(n) if log_weights is None else np.asarray(log_weights, dtype=float) + + # Effective sample size of the importance weights (Kish). The source run may have a low + # ESS (calmarg makes the extrinsic integral hard); fitting too many mixture components to + # too few effective samples STARVES the EM fit -- a component collapses onto ~1 sample and + # its covariance goes singular/NaN, poisoning the whole seeded proposal. Cap components by + # ESS below (mirrors the cal pilot's d(d+1)/2 reasoning). + _lwf = lw[np.isfinite(lw)] + if len(_lwf): + _w = np.exp(_lwf - _lwf.max()) + ess = float((_w.sum() ** 2) / np.sum(_w ** 2)) + else: + ess = float(n) + + out_groups = [] + for grp in groups: + if not all(p in samples for p in grp): + continue + d = len(grp) + sample_array = np.column_stack([np.asarray(samples[p], dtype=float) for p in grp]) + grp_bounds = np.array([list(bounds[p]) for p in grp], dtype=float) # (d, 2) + # need >~ (d+2) effective samples per component for a non-degenerate covariance + k = min(n_comp, max(1, int(ess // (d + 2))), max(1, sample_array.shape[0])) + model = GMM.gmm(k, grp_bounds, max_iters=max_iters) + # the model may run on cupy (GPU); move inputs onto its device first. + model.fit(model.identity_convert_togpu(sample_array), + log_sample_weights=model.identity_convert_togpu(lw)) + # model.means/.covariances are lists (length k) in the model's internal frame. + means = np.array([np.asarray(model.identity_convert(m)) for m in model.means]) # (k, d) + covs = np.array([np.asarray(model.identity_convert(c)) for c in model.covariances]) # (k, d, d) + weights = np.asarray(model.identity_convert(model.weights), dtype=float).reshape(-1) # (k,) + # DROP degenerate components. When the source run is starved (few effective samples + # vs n_comp x d), the EM fit collapses a component onto ~1 sample -> singular/NaN + # covariance and NaN mean. A single NaN component poisons the whole seeded proposal + # (NaN sampling-prior -> NaN lnL -> wrong integral). Keep only finite, positive-weight + # components and renormalize; if none survive, skip the group (it stays cold = safe). + good = (np.isfinite(means).all(axis=1) & np.isfinite(covs).reshape(len(covs), -1).all(axis=1) + & np.isfinite(weights) & (weights > 0)) + if not good.any(): + continue + means, covs, weights = means[good], covs[good], weights[good] + weights = weights / weights.sum() + out_groups.append(dict(params=list(grp), means=means, covariances=covs, + weights=weights, bounds=grp_bounds)) + return dict(kind="gmm", groups=out_groups) + + +def reconstruct_gmm(group, max_iters=1000, adapt=True, cov_inflate=1.0): + """Rebuild a RIFT gaussian_mixture_model.gmm from a stored breadcrumb group. + adapt=True -> the seeded components keep adapting in the next run (extrinsics drift a + little); adapt=False freezes them. + + cov_inflate (>=1) widens the seeded covariances (in the model's normalized frame). Default + 1.0 (no inflation): a FROZEN seed (the default seeding mode) should match the source + posterior, not be widened -- inflating only pushes samples past hard bounds (e.g. distance), + where the likelihood is NaN. Inflation is only useful for the adapt=True path (broaden so + the sampler can contract); the freeze path makes it unnecessary. + + All model arrays (means/covariances/weights AND bounds) are moved onto the model's device + (cupy on GPU): the sampler's score()/_normalize write into an xpy.empty array, so a + leftover numpy `self.bounds` raises 'non-scalar numpy.ndarray cannot be used for fill'.""" + GMM = _gmm_module() + means = np.asarray(group["means"]); covs = np.asarray(group["covariances"], dtype=float) * float(cov_inflate) + weights = np.asarray(group["weights"], dtype=float); bounds = np.asarray(group["bounds"], dtype=float) + k = means.shape[0] + model = GMM.gmm(k, bounds, max_iters=max_iters) + model.bounds = model.identity_convert_togpu(bounds) # must match self.xpy (GPU) + model.means = [model.identity_convert_togpu(means[i]) for i in range(k)] + model.covariances = [model.identity_convert_togpu(covs[i]) for i in range(k)] + model.weights = model.identity_convert_togpu(weights) + model.adapt = [bool(adapt)] * k + model.d = means.shape[1] + return model + + +def _permute_group(group, perm): + """Return a copy of a breadcrumb group with its parameter columns reordered by `perm` + (perm[j] = source column index that should land at output position j).""" + means = np.asarray(group["means"])[:, perm] + covs = np.asarray(group["covariances"])[:, perm][:, :, perm] + bounds = np.asarray(group["bounds"])[perm] + return dict(params=[group["params"][j] for j in perm], means=means, + covariances=covs, weights=group["weights"], bounds=bounds) + + +def gmm_dict_from_breadcrumb(extrinsic, params_ordered, adapt=True, existing_keys=None, cov_inflate=1.0): + """Build a gmm_dict {dim_group_tuple: gmm} to SEED mcsamplerEnsemble, from a breadcrumb + 'extrinsic' dict. dim_group_tuple are indices into `params_ordered` (the sampler's + parameter order this run), looked up by parameter NAME -- so the handoff is robust to a + different parameter ordering between runs. Groups whose params are not all present in + params_ordered this run are skipped (with no error). + + `existing_keys` (the sampler's actual gmm_dict keys this run) makes the seed robust to the + WITHIN-group parameter ORDER: the sampler may pair, e.g., (psi, phi_orb) while the + breadcrumb stored (phi_orb, psi). We match each breadcrumb group to the existing key with + the same dim SET, then permute the stored means/covariances/bounds columns into that key's + dim order -- so the seeded model lines up with how the sampler will draw/score it. Without + existing_keys the key is just the breadcrumb's own param order.""" + if extrinsic is None or extrinsic.get("kind") != "gmm": + return {} + name_to_idx = {p: i for i, p in enumerate(params_ordered)} + key_by_set = {frozenset(k): tuple(k) for k in existing_keys} if existing_keys is not None else None + gmm_dict = {} + for group in extrinsic["groups"]: + if not all(p in name_to_idx for p in group["params"]): + continue + grp_idx = [name_to_idx[p] for p in group["params"]] # dim index of each stored column + if key_by_set is not None: + target = key_by_set.get(frozenset(grp_idx)) + if target is None: + continue # sampler has no matching group + else: + target = tuple(grp_idx) + perm = [grp_idx.index(dim) for dim in target] # reorder stored cols -> target order + g = group if perm == list(range(len(perm))) else _permute_group(group, perm) + gmm_dict[target] = reconstruct_gmm(g, adapt=adapt, cov_inflate=cov_inflate) + return gmm_dict + + +# --------------------------------------------------------------------------- +# Proof-of-concept: fit a synthetic multi-cluster extrinsic posterior, round-trip it through +# a breadcrumb, and show a seeded GMM starts on the posterior (vs a cold wide prior). +# --------------------------------------------------------------------------- +if __name__ == "__main__": + from RIFT.calmarg import breadcrumbs + import tempfile, os + + rng = np.random.default_rng(0) + # A bimodal sky posterior (two sky modes) + a unimodal distance/inclination blob. + bounds = {"right_ascension": (0.0, 2 * np.pi), "declination": (-np.pi / 2, np.pi / 2), + "distance": (1.0, 1000.0), "inclination": (0.0, np.pi)} + n = 4000 + mode = rng.random(n) < 0.6 + ra = np.where(mode, rng.normal(1.0, 0.10, n), rng.normal(4.2, 0.15, n)) % (2 * np.pi) + dec = np.where(mode, rng.normal(0.2, 0.08, n), rng.normal(-0.4, 0.10, n)) + dist = np.clip(rng.normal(450.0, 60.0, n), 1, 1000) + incl = np.clip(rng.normal(1.1, 0.2, n), 0, np.pi) + samples = {"right_ascension": ra, "declination": dec, "distance": dist, "inclination": incl} + + ext = fit_extrinsic_proposal(samples, log_weights=None, bounds=bounds, n_comp=3) + print("fit %d groups: %s" % (len(ext["groups"]), [g["params"] for g in ext["groups"]])) + + # round-trip through a breadcrumb + p = os.path.join(tempfile.mkdtemp(), "ext.npz") + breadcrumbs.save(p, extrinsic=ext, meta=dict(iteration=1)) + g = breadcrumbs.load(p) + assert g["extrinsic"]["kind"] == "gmm" + assert np.allclose(g["extrinsic"]["groups"][0]["means"], ext["groups"][0]["means"]) + + # seed: reconstruct the gmm_dict against a (shuffled) params_ordered, draw from the + # seeded sky GMM, and check the draws land on the bimodal posterior (means recovered). + params_ordered = ["distance", "psi", "right_ascension", "phi_orb", "declination", "inclination"] + gmm_dict = gmm_dict_from_breadcrumb(g["extrinsic"], params_ordered) + sky_key = (params_ordered.index("right_ascension"), params_ordered.index("declination")) + assert sky_key in gmm_dict, "sky group not seeded" + sky = gmm_dict[sky_key] + draws = np.asarray(sky.identity_convert(sky.sample(3000))) + # nearest-mode recovery: each true mode should have draws clustered around it + for true_ra in (1.0, 4.2): + near = np.min(np.abs((draws[:, 0] - true_ra + np.pi) % (2 * np.pi) - np.pi)) + assert near < 0.5, "seeded GMM draws miss the sky mode at ra=%.1f" % true_ra + frac_mode1 = np.mean(np.abs(((draws[:, 0] - 1.0 + np.pi) % (2 * np.pi)) - np.pi) < 1.0) + print("seeded sky GMM: draws recover both modes; ~%.0f%% near mode-1 (true ~60%%)" % (100 * frac_mode1)) + print("PASS: extrinsic posterior -> breadcrumb -> seeded GMM reproduces the (bimodal) " + "sky distribution; ready to seed mcsamplerEnsemble's gmm_dict.") diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/generate_realizations.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/generate_realizations.py index a4a8fbfc8..8b79ab1c2 100644 --- a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/generate_realizations.py +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/generate_realizations.py @@ -77,30 +77,76 @@ def nodes_to_spline_coefficients_matrix(n_points): return np.linalg.solve(tmp1, tmp2) -def create_realizations(fname, T_segment,dT, fmin, fmax, n_spline_points, n_realizations): - # NOTE - # - the bilby tool (because it needs high computational efficiency, being done many times) is much harder to read. We will use library code, because we only call it ONCE PER RUN - # - similarly, the LI/bilby tool uses a slightly different representation, because they are trying to avoid transcendental operations to improve efficiency - # Conversion tool - # spline_matrix = nodes_to_spline_coefficients_matrix(n_spline_points) - # STEP 0: logarithmic frequency spacing in positive freequency -# print(fname, T_segment, dT, fmin, fmax, n_spline_points, n_realizations) - +def node_prior(fname, fmin, fmax, n_spline_points): + """Return the calibration PRIOR over spline nodes for one detector, as the + diagonal Gaussian implied by the envelope file. + + The per-detector node vector is laid out as + [amp_0 .. amp_{N-1}, phase_0 .. phase_{N-1}] (N = n_spline_points) + with amp node i ~ N(median_amp_i, sigma_amp_i) and phase node i ~ + N(median_phase_i, sigma_phase_i), independent (this is exactly the prior that + create_realizations() draws from). + + Returns dict(mean, sigma, node_log_f, n_nodes_amp) -- mean/sigma length 2N. + """ log_freq_spline_locations = np.linspace(np.log10(fmin), np.log10(fmax), n_spline_points) - - # Localize data to location dat_amp, dat_phase = retrieve_envelope_from_file(fname, frequency_array=10**log_freq_spline_locations) - # Create random spline realizations + mean = np.concatenate([dat_amp[:, 1], dat_phase[:, 1]]) + sigma = np.concatenate([dat_amp[:, 2], dat_phase[:, 2]]) + return dict(mean=mean, sigma=sigma, node_log_f=log_freq_spline_locations, + n_nodes_amp=int(n_spline_points)) + + +def prior_cal_breadcrumb_dict(env_dir, dets, fmin, fmax, n_spline_points, fmin_ifo=None): + """Build the 'cal' breadcrumb dict for the broad PRIOR, with proposal == prior. + + Suitable as an iteration-0 placeholder breadcrumb: seeding from it + (seed_realizations_from_breadcrumb) draws cal realizations from the prior with ZERO + importance weights (log prior - log proposal = 0), i.e. it is equivalent to the cold + prior draws -- but, unlike a 0-byte placeholder, it LOADS cleanly (so an older ILE binary + that does not guard against an empty placeholder will not crash on it). + + Layout matches seed_realizations_from_breadcrumb: the full node vector is concatenated per + detector in `dets` order as [det0_amp_0..,det0_phase_0..,det1_amp..,...]; dim = 2N*len(dets). + """ + import os + means = []; sigmas = []; node_log_f = None + for ifo in dets: + fmin_here = fmin + if fmin_ifo and ifo in fmin_ifo: + fmin_here = fmin_ifo[ifo] + pr = node_prior(os.path.join(env_dir, ifo + ".txt"), fmin_here, fmax, n_spline_points) + means.append(pr["mean"]); sigmas.append(pr["sigma"]) + if node_log_f is None: + node_log_f = pr["node_log_f"] + prior_mean = np.concatenate(means); prior_sigma = np.concatenate(sigmas) + return dict(proposal_mean=prior_mean, proposal_cov=np.diag(prior_sigma ** 2), + prior_mean=prior_mean, prior_sigma=prior_sigma, + node_log_f=node_log_f, n_nodes_amp=int(n_spline_points), dets=list(dets)) + + +def _draw_amp_phase_nodes(dat_amp, dat_phase, n_spline_points, n_realizations): + """Draw amp/phase spline nodes from the prior, in the EXACT random-number order + create_realizations() has always used (so seeded behavior is byte-identical).""" amp_rand_array = np.zeros((n_spline_points, n_realizations)) phase_rand_array = np.zeros((n_spline_points, n_realizations)) -# print(amp_rand_array.shape, phase_rand_array.shape) - # Create random amplitudes, phases - # - not efficient, for loop : use matrix operations to speed up! for indx in np.arange(n_spline_points): - amp_rand_array[indx,:] = np.random.normal(loc=dat_amp[indx,1], scale=dat_amp[indx,2], size=n_realizations) - phase_rand_array[indx,:] = np.random.normal(loc=dat_phase[indx,1], scale=dat_phase[indx,2], size=n_realizations) + amp_rand_array[indx, :] = np.random.normal(loc=dat_amp[indx, 1], scale=dat_amp[indx, 2], size=n_realizations) + phase_rand_array[indx, :] = np.random.normal(loc=dat_phase[indx, 1], scale=dat_phase[indx, 2], size=n_realizations) + return amp_rand_array, phase_rand_array - # Create realizations (complex-valued array for TWO_SIDED system + +def build_realizations_from_nodes(amp_rand_array, phase_rand_array, T_segment, dT, + fmin, fmax, log_freq_spline_locations): + """Build the complex two-sided per-realization calibration factor array from + spline-node values. Factored out of create_realizations() so the SAME spline + construction is reused both for prior draws and for proposal-seeded draws + (seed_realizations_from_breadcrumb). + + amp_rand_array, phase_rand_array : (n_spline_points, n_realizations). + Returns dat_out (npts_seg, n_realizations) complex, unity outside [fmin,fmax]. + """ + n_realizations = amp_rand_array.shape[1] deltaF_seg = 1./T_segment npts_seg = int(T_segment/dT) # Match array locations from lalsimutils.evaluate_fvals! @@ -108,7 +154,7 @@ def create_realizations(fname, T_segment,dT, fmin, fmax, n_spline_points, n_rea mask_positive = freq_locations_physical > 0 mask_negative = freq_locations_physical < 0 mask_in_range = np.logical_and(np.abs(freq_locations_physical) >= fmin , np.abs(freq_locations_physical) <= fmax) - + dat_out = np.ones((npts_seg, n_realizations),dtype=complex) # default factor is unity # Loop over realizations, build up spline @@ -124,11 +170,143 @@ def create_realizations(fname, T_segment,dT, fmin, fmax, n_spline_points, n_rea dat_out[mask_plus, indx] = cs_amp( log10_freq_pos_in_range )*np.exp(1j*cs_phase(log10_freq_pos_in_range)) dat_out[mask_minus, indx] = cs_amp( log10_minus_freq_neg_in_range )*np.exp(-1j*cs_phase(log10_minus_freq_neg_in_range)) -# print(log10_freq_pos_in_range, cs_amp(log10_freq_pos_in_range), cs_phase(log10_freq_pos_in_range) ) -# print(dat_out[mask_plus]) return dat_out +def create_realizations(fname, T_segment,dT, fmin, fmax, n_spline_points, n_realizations): + # NOTE + # - the bilby tool (because it needs high computational efficiency, being done many times) is much harder to read. We will use library code, because we only call it ONCE PER RUN + # - similarly, the LI/bilby tool uses a slightly different representation, because they are trying to avoid transcendental operations to improve efficiency + # Conversion tool + # spline_matrix = nodes_to_spline_coefficients_matrix(n_spline_points) + # STEP 0: logarithmic frequency spacing in positive freequency +# print(fname, T_segment, dT, fmin, fmax, n_spline_points, n_realizations) + + log_freq_spline_locations = np.linspace(np.log10(fmin), np.log10(fmax), n_spline_points) + + # Localize data to location + dat_amp, dat_phase = retrieve_envelope_from_file(fname, frequency_array=10**log_freq_spline_locations) + # Create random spline realizations (prior draws -- same RNG order as always) + amp_rand_array, phase_rand_array = _draw_amp_phase_nodes(dat_amp, dat_phase, n_spline_points, n_realizations) + + # Create realizations (complex-valued array for TWO_SIDED system + return build_realizations_from_nodes(amp_rand_array, phase_rand_array, T_segment, dT, + fmin, fmax, log_freq_spline_locations) + + +def draw_prior_realizations_with_nodes(env_dir, dets, T_segment, dT, fmin, fmax, + n_spline_points, n_realizations, + fmin_ifo=None, rng=None): + """Draw PRIOR calibration realizations AND keep the spline-node draws. + + Same prior as create_realizations(), but it returns the node vectors (which + create_realizations discards) so a pilot can fit a proposal over them. Used by + the ILE --calibration-dump-responsibilities path. + + env_dir : directory of per-detector envelope files .txt. + dets : detector order; the returned node vector is concatenated per det as + [det0_amp, det0_phase, det1_amp, det1_phase, ...] (breadcrumb layout). + + Returns dict with: + realizations : {ifo: (npts_seg, n_realizations) complex} + nodes : (n_realizations, 2*n_spline_points*len(dets)) prior draws + prior_mean : (dim,) diagonal-Gaussian prior mean over the full node vector + prior_sigma : (dim,) prior sigma + node_log_f : (n_spline_points,) log10 spline node frequencies (det 0) + n_nodes_amp : n_spline_points + dets : list + """ + import os + if rng is None: + rng = np.random.default_rng() + priors = [] + for ifo in dets: + fmin_here = fmin + if fmin_ifo and ifo in fmin_ifo: + fmin_here = fmin_ifo[ifo] + priors.append(node_prior(os.path.join(env_dir, ifo + ".txt"), fmin_here, fmax, n_spline_points)) + prior_mean = np.concatenate([p["mean"] for p in priors]) + prior_sigma = np.concatenate([p["sigma"] for p in priors]) + dim = prior_mean.shape[0] + # diagonal-Gaussian prior draws over the full node vector + nodes = prior_mean[None, :] + prior_sigma[None, :] * rng.standard_normal((n_realizations, dim)) + + n_amp = int(n_spline_points) + dim_per_det = 2 * n_amp + realizations = {} + for i_det, ifo in enumerate(dets): + fmin_here = fmin + if fmin_ifo and ifo in fmin_ifo: + fmin_here = fmin_ifo[ifo] + log_freq_spline_locations = np.linspace(np.log10(fmin_here), np.log10(fmax), n_spline_points) + block = nodes[:, i_det*dim_per_det:(i_det+1)*dim_per_det] + realizations[ifo] = build_realizations_from_nodes( + block[:, :n_amp].T, block[:, n_amp:].T, T_segment, dT, fmin_here, fmax, + log_freq_spline_locations) + return dict(realizations=realizations, nodes=nodes, prior_mean=prior_mean, + prior_sigma=prior_sigma, node_log_f=priors[0]["node_log_f"], + n_nodes_amp=n_amp, dets=list(dets)) + + +def seed_realizations_from_breadcrumb(bc, T_segment, dT, fmin, fmax, n_spline_points, + n_realizations, fmin_ifo=None, rng=None): + """Draw n_realizations calibration factors per detector from a LEARNED Gaussian + proposal (a breadcrumb), and return the Phase-0 importance weights. + + This is the production seed path (Option C): instead of drawing cal nodes from + the broad prior, draw them from the consolidated pilot proposal (concentrated on + the high-likelihood cal region), and carry log_w = log prior - log proposal so the + marginalization stays unbiased. + + bc : breadcrumb dict (breadcrumbs.load(...)) OR its ["cal"] sub-dict. The proposal + is a joint Gaussian over the FULL multi-detector node vector, concatenated over + bc["dets"] in order: [det0_amp, det0_phase, det1_amp, det1_phase, ...]. + fmin : scalar template fmin (used for all dets) OR ignored per-det if fmin_ifo given. + fmin_ifo : optional {ifo: fmin} for per-detector low-frequency cutoffs. + rng : numpy Generator (default: a fresh default_rng()). + + Returns (dat_out_dict, cal_log_weights, nodes): + dat_out_dict : {ifo: (npts_seg, n_realizations) complex} + cal_log_weights: (n_realizations,) = log prior - log proposal (shared across dets, + since it is ONE joint draw per realization). + nodes : (n_realizations, dim) the proposal-drawn node vectors (so a pilot + can refit a proposal from a seeded run). + """ + from RIFT.calmarg import adaptive + cal = bc["cal"] if (isinstance(bc, dict) and "cal" in bc) else bc + if rng is None: + rng = np.random.default_rng() + mean = np.asarray(cal["proposal_mean"], dtype=float) + cov = np.asarray(cal["proposal_cov"], dtype=float) + prior_mean = np.asarray(cal["prior_mean"], dtype=float) + prior_sigma = np.asarray(cal["prior_sigma"], dtype=float) + dets = list(cal["dets"]) + n_amp = int(cal["n_nodes_amp"]) + dim_per_det = 2 * n_amp + assert mean.shape == (dim_per_det * len(dets),), \ + "breadcrumb proposal dim %s != 2*n_nodes_amp*len(dets)=%d" % (mean.shape, dim_per_det*len(dets)) + + # Draw the full multi-detector node vector from the proposal; importance weights. + nodes = rng.multivariate_normal(mean, cov, size=n_realizations) # (n_real, dim) + log_q = adaptive._mvn_logpdf(nodes, mean, cov) + log_p = adaptive.log_prior(nodes, prior_mean, prior_sigma) + cal_log_weights = log_p - log_q + + dat_out_dict = {} + for i_det, ifo in enumerate(dets): + fmin_here = fmin + if fmin_ifo and ifo in fmin_ifo: + fmin_here = fmin_ifo[ifo] + log_freq_spline_locations = np.linspace(np.log10(fmin_here), np.log10(fmax), n_spline_points) + block = nodes[:, i_det*dim_per_det:(i_det+1)*dim_per_det] # (n_real, 2N) + amp_rand_array = block[:, :n_amp].T # (N, n_real) + phase_rand_array = block[:, n_amp:].T + dat_out_dict[ifo] = build_realizations_from_nodes( + amp_rand_array, phase_rand_array, T_segment, dT, fmin_here, fmax, + log_freq_spline_locations) + return dat_out_dict, cal_log_weights, nodes + + if __name__ == "__main__": from matplotlib import pyplot as plt diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/pilot.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/pilot.py new file mode 100644 index 000000000..c6cbb5149 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/pilot.py @@ -0,0 +1,167 @@ +""" +Calibration pilot / brute-force reference (Options A + C in DESIGN_adaptive_driver.md). + + A (brute-force reference, the ONLY validation): marginalize cal with a large PRIOR set, + converged. Slow, ground truth. `brute_force_logZcal`. + C (production pilot): harvest the top-fraction high-lnL points from the previous + iteration's *.composite, do full cal there, fit a Gaussian proposal, and seed the + next iteration's cal realizations with importance weights. `harvest_high_L`, + `fit_pilot_proposal`, `consolidate`, `seed_cal`. + +The actual DAG wiring (pilot_N || wide_N, consolidation barrier, pilot_N -> wide_{N+1}, +the iteration cap) lives in the pipeline builder; it is STUBBED here with TODOs so the +plan is remembered. The numeric core (fit, seed, brute-vs-seeded agreement) is real and +tested below. +""" +from __future__ import division + +import numpy as np +from scipy.special import logsumexp + +from RIFT.calmarg import adaptive, breadcrumbs + + +# --------------------------------------------------------------------------- +# C: harvest pilot points from a previous iteration's composite +# --------------------------------------------------------------------------- +def harvest_high_L(composite_path, top_fraction=0.05, lnL_col="lnL", max_points=512): + """Return the indices+rows of the top `top_fraction` of evaluated points by lnL from + a RIFT *.composite file (whitespace, named header). These are the pilot points where + we will do full calibration (the cal posterior is ~the same across the high-L region, + so a handful suffice).""" + arr = np.atleast_2d(np.genfromtxt(composite_path, names=True)) + names = arr.dtype.names + col = lnL_col if lnL_col in names else _guess_lnL_col(names) + lnL = arr[col] + n_keep = max(1, int(np.ceil(len(lnL) * top_fraction))) + if max_points: + n_keep = min(n_keep, max_points) + order = np.argsort(lnL)[::-1][:n_keep] + return order, arr[order] + + +def _guess_lnL_col(names): + for cand in ("lnL", "lnL_raw", "loglikelihood", "log_likelihood"): + if cand in names: + return cand + raise KeyError("no lnL-like column in composite; columns=%s" % (names,)) + + +# --------------------------------------------------------------------------- +# A: brute-force reference (prior-only, large n_cal) +# --------------------------------------------------------------------------- +def brute_force_logZcal(log_L): + """Self-normalized cal-marginalized log-likelihood from a LARGE PRIOR cal set: + log Z_cal = logmeanexp(log_L) (importance weights are uniform for prior draws). + Returns (logZ, neff).""" + log_L = np.asarray(log_L) + logZ = logsumexp(log_L) - np.log(len(log_L)) + neff = adaptive.neff_from_logweights(log_L) + return float(logZ), neff + + +# --------------------------------------------------------------------------- +# C: fit a pilot proposal, seed the next run, consolidate breadcrumbs +# --------------------------------------------------------------------------- +def fit_pilot_proposal(nodes, log_resp, prior_mean, prior_sigma, node_log_f, + n_nodes_amp, dets, beta=1.0, meta=None): + """Fit a Gaussian cal proposal from pilot evaluations and package it as a breadcrumb + `cal` dict. `log_resp` = posterior responsibility (log_w + log integral L) per + realization, averaged/accumulated over the harvested pilot points by the caller.""" + mean, cov = adaptive.fit_proposal(nodes, log_resp, beta) + return dict(proposal_mean=mean, proposal_cov=cov, + prior_mean=np.asarray(prior_mean), prior_sigma=np.asarray(prior_sigma), + node_log_f=np.asarray(node_log_f), n_nodes_amp=int(n_nodes_amp), + dets=list(dets)) + + +def seed_cal(cal_proposal, n_cal, rng=None): + """Draw `n_cal` cal node vectors from the learned proposal and return + (nodes, log_weights) where log_weights = log prior - log proposal (Phase 0 importance + weights for the marginalization). Feed nodes through + adaptive.nodes_to_cal_factors(...) per detector to get the actual cal factors.""" + rng = rng or np.random.default_rng() + mean = np.asarray(cal_proposal["proposal_mean"]) + cov = np.asarray(cal_proposal["proposal_cov"]) + nodes = rng.multivariate_normal(mean, cov, size=n_cal) + log_q = adaptive._mvn_logpdf(nodes, mean, cov) + log_p = adaptive.log_prior(nodes, np.asarray(cal_proposal["prior_mean"]), + np.asarray(cal_proposal["prior_sigma"])) + return nodes, (log_p - log_q) + + +def consolidate(breadcrumb_paths, out_path=None): + """Combine cal proposals from several pilot breadcrumbs into one (the consolidation + job between iteration N and N+1). Gaussian case: precision-weighted combination + (a moment-matched product/average of the per-pilot Gaussians).""" + cals = [breadcrumbs.load(p)["cal"] for p in breadcrumb_paths] + cals = [c for c in cals if c is not None] + if not cals: + raise ValueError("no cal proposals to consolidate") + # precision-weighted mean, average covariance (robust, simple) + Ps = [np.linalg.inv(c["proposal_cov"]) for c in cals] + P = np.sum(Ps, axis=0) + cov = np.linalg.inv(P) + mean = cov @ np.sum([Pi @ c["proposal_mean"] for Pi, c in zip(Ps, cals)], axis=0) + out = dict(cals[0]); out["proposal_mean"] = mean; out["proposal_cov"] = cov + if out_path: + breadcrumbs.save(out_path, cal=out, meta=dict(consolidated_from=len(cals))) + return out + + +# --------------------------------------------------------------------------- +# DAG job stubs (Option C pipeline wiring -- TODO; see DESIGN_adaptive_driver.md) +# --------------------------------------------------------------------------- +def pilot_job(prev_composite, data_args, out_breadcrumb, top_fraction=0.05, n_cal_full=1000): + """STUB. pilot_N: harvest top-fraction points from prev_composite, run FULL cal at + each (large prior n_cal, parallel), fit the proposal, write a breadcrumb. + TODO: wire to the ILE precompute/likelihood to get per-point per-realization lnL.""" + raise NotImplementedError("pilot_job: pipeline wiring TODO (see DESIGN_adaptive_driver.md)") + + +def consolidation_job(pilot_breadcrumbs, out_breadcrumb): + """consolidation_N: collect pilot breadcrumbs -> one consolidated proposal that seeds + wide_{N+1}. (The numeric core is `consolidate` above.)""" + return consolidate(pilot_breadcrumbs, out_path=out_breadcrumb) + + +# --------------------------------------------------------------------------- +# A-vs-C validation: brute force == pilot-seeded on Z_cal, at far higher efficiency +# --------------------------------------------------------------------------- +if __name__ == "__main__": + rng = np.random.default_rng(7) + dim = 8 + prior_mean = np.zeros(dim); prior_sigma = np.ones(dim) + true_node = prior_mean + 2.0 * prior_sigma * rng.standard_normal(dim) / np.sqrt(dim) + like_sigma = 0.4 + + def log_like(nodes): # extrinsic-marg log-like proxy (no prior) + z = (nodes - true_node) / like_sigma + return -0.5 * np.sum(z * z, axis=1) + + # A: brute force, large PRIOR set (ground truth) + big = rng.multivariate_normal(prior_mean, np.diag(prior_sigma ** 2), size=20000) + logZ_brute, neff_brute = brute_force_logZcal(log_like(big)) + print("A brute force : logZcal=%.4f neff=%.1f / 20000" % (logZ_brute, neff_brute)) + + # C: learn the proposal (pilot), then seed a SMALL set and importance-weight + res = adaptive.adaptive_cal(log_like, prior_mean, prior_sigma, n_nodes_amp=dim // 2, + n_real=300, n_iter=6, rng=rng) + cal = dict(proposal_mean=res["proposal_mean"], proposal_cov=res["proposal_cov"], + prior_mean=prior_mean, prior_sigma=prior_sigma, + node_log_f=np.linspace(1, 3, dim // 2), n_nodes_amp=dim // 2, + dets=["H1", "L1", "V1"]) + nodes, log_w = seed_cal(cal, n_cal=300, rng=rng) + log_resp = log_w + log_like(nodes) + # UNBIASED importance estimate Z_cal = (1/M) sum_c w_c L_c, w_c = prior/proposal, + # E[w]=1 -> normalize by log(M), NOT logsumexp(log_w) (the biased self-normalized form) + logZ_seeded = logsumexp(log_resp) - np.log(len(nodes)) + neff_seeded = adaptive.neff_from_logweights(log_resp) + print("C pilot-seeded: logZcal=%.4f neff=%.1f / 300" % (logZ_seeded, neff_seeded)) + print("agreement |dlogZ| = %.4f ; efficiency gain x%.0f" + % (abs(logZ_seeded - logZ_brute), (neff_seeded / 300) / (neff_brute / 20000))) + + assert abs(logZ_seeded - logZ_brute) < 0.1, "pilot-seeded Z disagrees with brute force" + assert (neff_seeded / 300) > 20 * (neff_brute / 20000), "pilot did not improve efficiency" + print("\nPASS: pilot-seeded (C) reproduces the brute-force reference (A) at far higher " + "effective sampling.") diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_cal_mc_error.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_cal_mc_error.py new file mode 100644 index 000000000..23461ecb8 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_cal_mc_error.py @@ -0,0 +1,103 @@ +"""Unit test for adaptive.cal_mc_error_from_components. + +Validates the delta-method calibration MC error budget two ways, no GPU/lal needed: + +1. LOGNORMAL CLOSED FORM: for lnL_c ~ N(mu, s^2) iid, Var(ln Zhat) = (e^{s^2}-1)/n_cal + (to leading order). The estimator must reproduce this from a single draw set. + +2. BRUTE FORCE: redraw the cal set many times, measure the true scatter of + ln Zhat = logsumexp_c(lnL_c) - log n_cal directly, compare. + +Also checks the importance-weighted (cal_log_weights) path is unbiased, and that the +extrinsic batch weighting reduces to the same answer when responsibilities are +extrinsic-independent. + +Run: python -m RIFT.calmarg.test_cal_mc_error +""" +import numpy as np +from scipy.special import logsumexp + +from RIFT.calmarg.adaptive import cal_mc_error_from_components + + +def _true_scatter(sig, n_cal, trials, rng): + x = rng.normal(0.0, sig, size=(trials, n_cal)) + return np.std(logsumexp(x, axis=1) - np.log(n_cal)) + + +def test_lognormal_closed_form(): + rng = np.random.default_rng(7) + for sig in [0.3, 0.7, 1.0]: + n_cal = 400 + analytic = np.sqrt((np.exp(sig ** 2) - 1.0) / n_cal) + # average the estimator over independent draw sets (the estimator is itself + # a one-draw-set statistic, so compare in expectation) + est = [] + for _ in range(200): + comp = rng.normal(0.0, sig, size=(1, n_cal)) # 1 extrinsic sample suffices + s, neff, a = cal_mc_error_from_components(comp) + est.append(s) + assert abs(a.sum() - 1.0) < 1e-12 + est = np.mean(est) + assert abs(est - analytic) / analytic < 0.15, (sig, est, analytic) + print("test_lognormal_closed_form: OK") + + +def test_brute_force_scatter(): + rng = np.random.default_rng(11) + sig, n_cal = 1.2, 100 + truth = _true_scatter(sig, n_cal, 20000, rng) + est = np.mean([cal_mc_error_from_components(rng.normal(0, sig, (1, n_cal)))[0] + for _ in range(300)]) + # delta method degrades as neff_cal drops; require agreement within 30% + assert abs(est - truth) / truth < 0.30, (est, truth) + print("test_brute_force_scatter: OK (true {:.3f}, est {:.3f})".format(truth, est)) + + +def test_neff_dominated(): + # one realization dominating -> neff ~ 1 and a loud (lower-bound) sigma + comp = np.full((4, 50), -100.0) + comp[:, 3] = 0.0 + s, neff, a = cal_mc_error_from_components(comp) + assert neff < 1.5 + assert np.argmax(a) == 3 + print("test_neff_dominated: OK (neff {:.2f})".format(neff)) + + +def test_importance_weights_consistency(): + # drawing from a proposal with weights must agree with prior draws in expectation + rng = np.random.default_rng(3) + n_cal, sig = 800, 0.8 + # prior draws + s_prior = np.mean([cal_mc_error_from_components(rng.normal(0, sig, (1, n_cal)))[0] + for _ in range(100)]) + # 'proposal' = prior here, with identically zero log-weights: must match exactly in law + s_w = np.mean([cal_mc_error_from_components(rng.normal(0, sig, (1, n_cal)), + cal_log_weights=np.zeros(n_cal))[0] + for _ in range(100)]) + assert abs(s_prior - s_w) / s_prior < 0.2 + print("test_importance_weights_consistency: OK") + + +def test_extrinsic_batch_weighting(): + # responsibilities ~extrinsic-independent: a batch with a common per-sample offset + # (the extrinsic-dependent part) must give the same answer as a single sample. + rng = np.random.default_rng(5) + n_cal = 200 + base = rng.normal(0, 1.0, n_cal) + offsets = rng.normal(0, 5.0, 64) # huge extrinsic spread + comp = offsets[:, None] + base[None, :] + s_batch, neff_b, _ = cal_mc_error_from_components(comp) + s_one, neff_1, _ = cal_mc_error_from_components(base[None, :]) + assert abs(s_batch - s_one) < 1e-10 + assert abs(neff_b - neff_1) < 1e-8 + print("test_extrinsic_batch_weighting: OK") + + +if __name__ == "__main__": + test_lognormal_closed_form() + test_brute_force_scatter() + test_neff_dominated() + test_importance_weights_consistency() + test_extrinsic_batch_weighting() + print("ALL OK") diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_calmarg_reduction.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_calmarg_reduction.py new file mode 100644 index 000000000..82b4de396 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_calmarg_reduction.py @@ -0,0 +1,100 @@ +""" +Validate the n_cal>1 calibration-marginalization reduction in +DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop against a brute-force reference: +running the (unchanged) n_cal==1 path on each realization block separately and +combining with lnL = logsumexp_c(lnL_c) - log(n_cal). + +Runs entirely on CPU (xpy=np), so no GPU required. +""" +import numpy as np +import lal +from scipy.special import logsumexp +import RIFT.likelihood.factored_likelihood as fl + +rng = np.random.default_rng(1234) + +det = "H1" +n_lms = 2 +N_window = 256 # per-realization buffer length (must exceed sky time-delay spread + npts) +npts = 16 # integration sub-window (len(tvals)) +n_cal = 5 +npts_extrinsic = 6 +deltaT = 1.0 / 4096 + +# lookup table of (l,m) pairs +lookupNKDict = {det: np.array([[2, 2], [2, -2]], dtype=int)} + +# concatenated rholm timeseries: (n_lms, N_window*n_cal); block c is realization c +npts_full = N_window * n_cal +rho = (rng.standard_normal((n_lms, npts_full)) + 1j*rng.standard_normal((n_lms, npts_full))) +rholmsArrayDict = {det: rho} + +# template-template cross terms (Hermitian-ish; values don't matter for the identity) +U = (rng.standard_normal((n_lms, n_lms)) + 1j*rng.standard_normal((n_lms, n_lms))) +U = U + U.conj().T +V = (rng.standard_normal((n_lms, n_lms)) + 1j*rng.standard_normal((n_lms, n_lms))) +ctUArrayDict = {det: U} +ctVArrayDict = {det: V} + +epochDict = {det: 0.0} + +# extrinsic parameter vector (mock P_vec) +class PV: pass +P = PV() +P.phi = rng.uniform(0, 2*np.pi, npts_extrinsic) +P.theta = rng.uniform(0.2, np.pi-0.2, npts_extrinsic) +P.psi = rng.uniform(0, np.pi, npts_extrinsic) +P.incl = rng.uniform(0.2, np.pi-0.2, npts_extrinsic) +P.phiref = rng.uniform(0, 2*np.pi, npts_extrinsic) +P.dist = np.full(npts_extrinsic, 500.0) * (lal.PC_SI*1e6) # 500 Mpc +P.tref = 1000000000.0 +P.deltaT = deltaT +# Place the integration window near the middle of the buffer so ifirst stays in +# [0, N_window-npts] for all sky positions (TimeDelayFromEarthCenter is +-0.021s). +epochDict[det] = P.tref - 0.03 + +tvals = np.linspace(-npts//2*deltaT, npts//2*deltaT, npts) + +# --- reference: per-block n_cal==1 evaluations, combined by hand --- +lnL_blocks = np.zeros((n_cal, npts_extrinsic)) +for c in range(n_cal): + block = rho[:, c*N_window:(c+1)*N_window].copy() + lnL_blocks[c] = fl.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop( + tvals, P, lookupNKDict, {det: block}, ctUArrayDict, ctVArrayDict, epochDict, + Lmax=2, xpy=np, n_cal=1) +lnL_ref = logsumexp(lnL_blocks, axis=0) - np.log(n_cal) + +# --- new path: single call with n_cal>1 --- +lnL_new = fl.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop( + tvals, P, lookupNKDict, rholmsArrayDict, ctUArrayDict, ctVArrayDict, epochDict, + Lmax=2, xpy=np, n_cal=n_cal) + +print("reference lnL :", np.array(lnL_ref)) +print("in-loop lnL :", np.array(lnL_new)) +maxerr = np.max(np.abs(np.array(lnL_new) - np.array(lnL_ref))) +print("max abs error :", maxerr) +assert maxerr < 1e-9, "MISMATCH: cal-marg reduction != brute-force reference" + +# --- return_cal_components: raw per-realization integrated log L, (npts_extrinsic, n_cal) --- +# Collapsing it by hand must reproduce the cal-marg lnL: +# lnL_marg(i) = logsumexp_c comp[i,c] - log(n_cal) (uniform weights) +comp = fl.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop( + tvals, P, lookupNKDict, rholmsArrayDict, ctUArrayDict, ctVArrayDict, epochDict, + Lmax=2, xpy=np, n_cal=n_cal, return_cal_components=True) +comp = np.array(comp) +assert comp.shape == (npts_extrinsic, n_cal), "cal_components shape %s" % (comp.shape,) +# each column must equal the per-block n_cal==1 evaluation (raw, no weight) +assert np.max(np.abs(comp.T - lnL_blocks)) < 1e-9, "cal_components != per-block reference" +lnL_from_comp = logsumexp(comp, axis=1) - np.log(n_cal) +assert np.max(np.abs(lnL_from_comp - np.array(lnL_new))) < 1e-9, \ + "collapse of cal_components != cal-marg lnL" +print("cal_components check: max err vs blocks = %.2e ; collapse matches lnL" % + np.max(np.abs(comp.T - lnL_blocks))) + +# --- also confirm n_cal==1 on the full concat == block-0 evaluation (regression) --- +lnL_n1_full = fl.DiscreteFactoredLogLikelihoodViaArrayVectorNoLoop( + tvals, P, lookupNKDict, {det: rho[:, :N_window].copy()}, ctUArrayDict, ctVArrayDict, + epochDict, Lmax=2, xpy=np, n_cal=1) +assert np.allclose(np.array(lnL_n1_full), np.array(lnL_blocks[0])), "block-0 regression failed" + +print("\nPASS: in-loop calibration marginalization matches brute-force reference.") diff --git a/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_precompute_alignment.py b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_precompute_alignment.py new file mode 100644 index 000000000..8439571bd --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/calmarg/test_precompute_alignment.py @@ -0,0 +1,82 @@ +""" +Regression test for the calibration-marginalization PRECOMPUTE time alignment. + +This exercises the real PrecomputeLikelihoodTerms / ComputeModeIPTimeSeries path +(unlike backtest_calmarg.py / test_calmarg_reduction.py, which feed synthetic rholms +with a hand-set epoch and therefore cannot catch a precompute-alignment bug). + +With the calibration factor set to 1 (identity), the calibration-marginalized rholm +series must, block by block, reproduce the non-calibration rholm series -- same data AND +the same epoch. A wrong epoch on the concatenated series (the bug fixed in this branch) +shifts ifirst into the wrong realization block downstream, the signal is zeroed, and the +calmarg likelihood collapses. The epoch check below is the one that fails on that bug. + +Runs on CPU (no GPU needed). +""" +import numpy as np +import lal +import lalsimulation as lalsim + +import RIFT.lalsimutils as lalsimutils +import RIFT.likelihood.factored_likelihood as fl + +fSample = 4096.0 +fmin = 30.0 +fmax = 1700.0 +event_time = 1000000000.0 +t_window = 0.1 + +# A short BBH so the test is fast. +Psig = lalsimutils.ChooseWaveformParams( + fmin=fmin, radec=True, incl=0.0, phiref=0.0, theta=0.2, phi=0.0, psi=0.0, + m1=30 * lal.MSUN_SI, m2=30 * lal.MSUN_SI, + detector='H1', dist=200e6 * lal.PC_SI, deltaT=1. / fSample, + tref=event_time, deltaF=1. / 4.) + +data_dict = {} +for det in ("H1", "L1", "V1"): + P = Psig.manual_copy(); P.detector = det + data_dict[det] = lalsimutils.non_herm_hoff(P) +psd_dict = {det: lalsim.SimNoisePSDaLIGOZeroDetHighPower for det in data_dict} + +Lmax = 2 +n_cal = 5 + +# baseline (no calibration marginalization) +rholms_intp_b, ct_b, ctV_b, rholms_base, snr_b, _ = fl.PrecomputeLikelihoodTerms( + event_time, t_window, Psig, data_dict, psd_dict, Lmax, fmax, + analyticPSD_Q=True, verbose=False, quiet=True, skip_interpolation=True) + +# calibration marginalization with the IDENTITY calibration (factor == 1) +cal_real = {det: np.ones((data_dict[det].data.length, n_cal), dtype=complex) + for det in data_dict} +rholms_intp_c, ct_c, ctV_c, rholms_cal, snr_c, _ = fl.PrecomputeLikelihoodTerms( + event_time, t_window, Psig, data_dict, psd_dict, Lmax, fmax, + analyticPSD_Q=True, verbose=False, quiet=True, skip_interpolation=True, + calibration_realizations=cal_real) + +ok = True +for det in data_dict: + for pair in rholms_base[det]: + base = rholms_base[det][pair] + cal = rholms_cal[det][pair] + N_window = base.data.length + # (1) concatenated length is n_cal blocks + assert cal.data.length == N_window * n_cal, \ + "%s %s: cal length %d != %d*%d" % (det, pair, cal.data.length, N_window, n_cal) + # (2) EPOCH must match the non-calibration series (the alignment bug) + d_epoch = abs(float(cal.epoch) - float(base.epoch)) + # (3) every block must reproduce the baseline rholm (cal factor == 1) + block_err = 0.0 + for c in range(n_cal): + blk = cal.data.data[c * N_window:(c + 1) * N_window] + block_err = max(block_err, float(np.max(np.abs(blk - base.data.data)))) + flag_e = "OK" if d_epoch < 1e-9 else "**EPOCH MISMATCH**" + flag_b = "OK" if block_err < 1e-6 else "**BLOCK MISMATCH**" + print("%s %s : |delta epoch|=%.3e %s max|block-baseline|=%.3e %s" % ( + det, pair, d_epoch, flag_e, block_err, flag_b)) + if d_epoch >= 1e-9 or block_err >= 1e-6: + ok = False + +assert ok, "calmarg precompute alignment MISMATCH (epoch and/or block data)" +print("\nPASS: calmarg precompute is time-aligned with the baseline (epoch + per-block data).") diff --git a/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/config.py b/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/config.py index 171852982..f9e067032 100644 --- a/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/config.py +++ b/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/config.py @@ -208,12 +208,24 @@ def validate_config(cfg) -> None: if not isinstance(_get(arch, "n-samples-per-job"), int) or _get(arch, "n-samples-per-job") <= 0: raise ValueError("arch.n-samples-per-job must be a positive integer.") - # post: at least one of (coords-fit) must be set + # post: must have at least one fit dim AND at least one MC sampling dim. + # Fit basis = coords-fit + coords-implied; MC basis = coords-fit + coords-nofit. + # Pre-decoupling this only required coords-fit, because the fit basis was + # forced to equal the MC basis -- now an EOS-style "fit in a transformed + # basis" config can legally have empty coords-fit (everything routed via + # coords-implied + coords-nofit through the coordinate plugin). post = _get(cfg, "post") - if not _get(post, "coords-fit"): + has_fit = bool(_get(post, "coords-fit")) or bool(_get(post, "coords-implied")) + has_samp = bool(_get(post, "coords-fit")) or bool(_get(post, "coords-nofit")) + if not has_fit: raise ValueError( - "post.coords-fit must list at least one parameter " - "(e.g. 'x y z')." + "post: must list at least one fit dimension " + "(coords-fit or coords-implied; e.g. 'x y z' or 'u v w')." + ) + if not has_samp: + raise ValueError( + "post: must list at least one MC sampling dimension " + "(coords-fit or coords-nofit; e.g. 'x y z')." ) # init: must have either file or generation set diff --git a/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/coords.py b/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/coords.py index bc039846a..ccd6b18d7 100644 --- a/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/coords.py +++ b/MonteCarloMarginalizeCode/Code/RIFT/hyperpipe/coords.py @@ -177,12 +177,22 @@ def from_strings( ``coords_nofit`` : "delta_mc s1z s2z" (optional) ``likelihood_factor``: (module, function, ini) (any element may be None) """ - params = parse_parameter_list(coords_fit) - ranges = parse_range_string(coords_sample) - unknown = set(ranges) - set(params) + params = parse_parameter_list(coords_fit) + implied = parse_parameter_list(coords_implied) + nofit = parse_parameter_list(coords_nofit) + ranges = parse_range_string(coords_sample) + # coords-sample provides INTEGRATION ranges, so it has to cover every + # name in the MC SAMPLING basis -- which is coords-fit + coords-nofit. + # (implied names are fit-only and don't need a sample range.) Pre- + # decoupling this was just coords-fit because nofit/implied were + # rarely used and the sampling basis was forced to equal the fit + # basis; now we have to allow the nofit names too. + sampling_basis = set(params) | set(nofit) + unknown = set(ranges) - sampling_basis if unknown: raise ValueError( - f"coords-sample names a parameter not in coords-fit: {sorted(unknown)!r}" + f"coords-sample names a parameter not in coords-fit or " + f"coords-nofit: {sorted(unknown)!r}" ) lf: Optional[Tuple[str, Optional[str], Optional[str]]] = None if likelihood_factor: @@ -195,8 +205,8 @@ def from_strings( name=name or None, parameters=params, parameter_ranges=ranges, - implied=parse_parameter_list(coords_implied), - nofit=parse_parameter_list(coords_nofit), + implied=implied, + nofit=nofit, likelihood_factor=lf, ) @@ -210,13 +220,30 @@ def validate(self, strict_import: bool = False) -> None: environment (singularity image, OSG worker) and not necessarily on the submit host. """ - if not self.parameters: - raise ValueError("HyperCoordSpec requires at least one fitting parameter.") - missing = [p for p in self.parameters if p not in self.parameter_ranges] + # The fit basis is coords-fit + coords-implied; the sampling basis is + # coords-fit + coords-nofit. Both must be non-empty for the run to + # make sense. Pre-decoupling this only required coords-fit -- now + # an "EOS-style fit in a transformed basis" config can legally have + # empty coords-fit (everything goes through implied + nofit). + if not self.parameters and not self.implied: + raise ValueError( + "HyperCoordSpec requires at least one fit dimension " + "(coords-fit or coords-implied)." + ) + if not self.parameters and not self.nofit: + raise ValueError( + "HyperCoordSpec requires at least one MC sampling dimension " + "(coords-fit or coords-nofit)." + ) + # Every name in the SAMPLING basis must have an integration range + # (the integrator reads prior_range_map[p] for p in low_level_coord_names). + sampling_names = list(self.parameters) + list(self.nofit) + missing = [p for p in sampling_names if p not in self.parameter_ranges] if missing: raise ValueError( - f"No integration range supplied for parameter(s): {missing!r}. " - "Every entry in coords-fit must appear in coords-sample." + f"No integration range supplied for sampling parameter(s): " + f"{missing!r}. Every entry in coords-fit and coords-nofit must " + "appear in coords-sample." ) for p, (lo, hi) in self.parameter_ranges.items(): if not lo < hi: @@ -266,7 +293,9 @@ def to_parameter_args(self) -> str: bits.append(f"--parameter-implied {p}") for p in self.nofit: bits.append(f"--parameter-nofit {p}") - for p in self.parameters: + # Integration ranges cover the MC SAMPLING basis (parameters + nofit). + # Implied coordinates are fit-only and don't have a sampling range. + for p in list(self.parameters) + list(self.nofit): lo, hi = self.parameter_ranges[p] bits.append( f"--integration-parameter-range {p}:[{self._fmt_num(lo)},{self._fmt_num(hi)}]" @@ -290,12 +319,16 @@ def to_post_args(self) -> str: def to_puff_args(self, force_away: float = 0.03, puff_factor: float = 0.5) -> str: """Emit the puff-stage arg block. - By default we puff in every fitting parameter; this is what every - existing hyperpipe example does. Extra flags can be appended by the - caller. + The puff lane reads / writes grid files in the data-file column + basis, which is the MC sampling basis (coords-fit + coords-nofit). + Pre-decoupling this only emitted --parameter for coords-fit because + the sampling basis was forced to equal the fit basis; once those + diverge (EOSPosterior with --parameter-implied for a transformed + fit basis), the puff lane must continue to operate on the data- + file columns -- i.e. coords-fit + coords-nofit. """ bits = [f"--force-away {force_away}", f"--puff-factor {puff_factor}"] - for p in self.parameters: + for p in list(self.parameters) + list(self.nofit): bits.append(f"--parameter {p}") return " ".join(bits) @@ -304,8 +337,12 @@ def to_test_args(self, method: str = "JS", threshold: float = 0.05) -> str: Mirrors the args_test.txt pattern from the Gaussian demo: ``--parameter x --parameter y --parameter z --method JS --threshold 0.05`` + + Like the puff lane, this operates on the SAMPLING basis (coords-fit + + coords-nofit) -- the convergence-test driver reads grid / posterior + files whose columns are in the sampling basis. """ - bits = [f"--parameter {p}" for p in self.parameters] + bits = [f"--parameter {p}" for p in list(self.parameters) + list(self.nofit)] bits.append(f"--method {method}") bits.append(f"--threshold {threshold}") return " ".join(bits) diff --git a/MonteCarloMarginalizeCode/Code/RIFT/integrators/MonteCarloEnsemble.py b/MonteCarloMarginalizeCode/Code/RIFT/integrators/MonteCarloEnsemble.py index ae1d2cd6e..133661090 100755 --- a/MonteCarloMarginalizeCode/Code/RIFT/integrators/MonteCarloEnsemble.py +++ b/MonteCarloMarginalizeCode/Code/RIFT/integrators/MonteCarloEnsemble.py @@ -95,7 +95,8 @@ class integrator: ''' def __init__(self, d, bounds, gmm_dict, n_comp, n=None, prior=None, - user_func=None, proc_count=None, L_cutoff=None, use_lnL=False,return_lnI=False,gmm_adapt=None,gmm_epsilon=None,tempering_exp=1,temper_log=False,lnw_failure_cut=None): + user_func=None, proc_count=None, L_cutoff=None, use_lnL=False,return_lnI=False,gmm_adapt=None,gmm_epsilon=None,tempering_exp=1,temper_log=False,lnw_failure_cut=None, + tempering_adapt=False, ess_target=None, ess_floor=None): # if 'return_lnI' is active, 'integral' holds the *logarithm* of the integral. # user-specified parameters self.d = d @@ -152,6 +153,22 @@ def __init__(self, d, bounds, gmm_dict, n_comp, n=None, prior=None, self.cumulative_p_s = self.xpy.empty(0) self.tempering_exp=tempering_exp self.temper_log=temper_log + # --- ESS-based tempering self-protection / self-tuning ------------- + # tempering_adapt: choose the refit exponent each chunk so the + # effective sample size of the refit weights hits ess_target + # (the user exponent becomes a cap, not a requirement). + # Always on (any settings): if the user exponent would leave the + # refit with ESS < ess_floor, the exponent is clamped down for that + # refit only. The evidence integral never uses these weights. + self.tempering_adapt = tempering_adapt + # largest number of mixture components in any group (for the floor) + if isinstance(n_comp, dict): + _k_max = max([v for v in n_comp.values()]) if len(n_comp)>0 else 1 + else: + _k_max = n_comp if n_comp else 1 + self.ess_floor = ess_floor if ess_floor else max(10.0, 2.0*_k_max) + self.ess_target = ess_target if ess_target else max(50.0, 0.05*self.n) + self.tempering_exp_running = tempering_exp # last exponent actually used if L_cutoff is None: self.L_cutoff = -1 else: @@ -191,17 +208,118 @@ def _sample(self): self.sample_array[:,dim] = temp_samples[:,index] index += 1 + def _log_ess(self, log_w): + """log of the Kish effective sample size of log-weights log_w.""" + return 2.*_xpy_logsumexp(log_w) - _xpy_logsumexp(2.*log_w) + + def _solve_tempering_exp(self, lnL, log_pq): + """ + Choose the tempering exponent beta for THIS refit from the effective + sample size of beta*lnL + log_pq (log_pq = ln p - ln p_s). + + - tempering_adapt: bisect so ESS(beta) ~ self.ess_target, with + beta <= beta_max = max(1, tempering_exp). As the proposal converges + the lnL spread across the cloud shrinks and beta rises automatically. + - otherwise: keep the user exponent unless ESS(user) < self.ess_floor, + in which case bisect down to the floor (pure safety net; cannot + crash the fit regardless of settings). + Returns (beta, log_ess_at_beta). + """ + ln_floor = self.xpy.log(self.ess_floor) + ln_target = self.xpy.log(self.ess_target) + beta_user = self.tempering_exp + if self.tempering_adapt: + beta_hi = max(1.0, beta_user) + ln_goal = ln_target + else: + beta_hi = beta_user + ln_goal = ln_floor + if self._log_ess(beta_user*lnL + log_pq) >= ln_floor: + return beta_user, self._log_ess(beta_user*lnL + log_pq) + # ESS is (near-)monotone decreasing in beta for peaked lnL; bisect. + if self._log_ess(beta_hi*lnL + log_pq) >= ln_goal: + return beta_hi, self._log_ess(beta_hi*lnL + log_pq) + lo, hi = 0.0, beta_hi + for _ in range(40): + mid = 0.5*(lo+hi) + if self._log_ess(mid*lnL + log_pq) >= ln_goal: + lo = mid + else: + hi = mid + return lo, self._log_ess(lo*lnL + log_pq) + def _train(self): sample_array, value_array, sampling_prior_array = self.xpy.copy(self.sample_array), self.xpy.copy(self.value_array), self.xpy.copy(self.sampling_prior_array) if self.use_lnL: lnL = value_array else: lnL = self.xpy.log(value_array+regularize_log_scale) - - log_weights = self.tempering_exp*lnL + self.xpy.log(self.prior_array) - sampling_prior_array + + # drop NaN evaluations up front (a NaN poisons every logsumexp below); + # -inf is fine (zero weight) and is kept. + prior_array = self.prior_array + mask_ok = ~self.xpy.isnan(lnL) + if not bool(self.xpy.all(mask_ok)): + sample_array = sample_array[mask_ok] + lnL = lnL[mask_ok] + sampling_prior_array = sampling_prior_array[mask_ok] + prior_array = prior_array[mask_ok] + + # replace -inf lnL (zero likelihood) by a finite very-low value: + # beta=0 would otherwise produce 0*(-inf)=NaN in the tempered weights + if not bool(self.xpy.all(self.xpy.isfinite(lnL))): + lnL_min = self.xpy.min(self.xpy.where(self.xpy.isfinite(lnL), lnL, self.xpy.inf)) + if not bool(self.xpy.isfinite(lnL_min)): + lnL_min = 0.0 + lnL = self.xpy.where(self.xpy.isfinite(lnL), lnL, lnL_min - 1000.) + + # ln p - ln p_s (NOTE: log of the sampling prior. The legacy code + # subtracted the *raw* sampling_prior_array from a log-quantity.) + log_pq = self.xpy.log(self.xpy.maximum(prior_array, 1e-300)) \ + - self.xpy.log(self.xpy.maximum(sampling_prior_array, 1e-300)) + + # ESS-protected/self-tuned tempering exponent for this refit + beta, log_ess = self._solve_tempering_exp(lnL, log_pq) + self.tempering_exp_running = beta + log_weights = beta*lnL + log_pq + adapt_mode = 'beta' + # Honest ESS of the FULL posterior weights (beta=1): how reachable the + # posterior is from the current proposal cloud. + log_ess1 = self._log_ess(lnL + log_pq) if self.temper_log: log_weights = self.xpy.log(self.xpy.maximum(lnL,1e-5)) - + elif self.tempering_adapt and bool(log_ess1 < self.xpy.log(self.ess_floor)): + # BOOTSTRAP (rank-elite refit, cross-entropy-method style): while + # the posterior is out of reach of the cloud, beta-tempered + # honest weights equilibrate the proposal near the PRIOR (the + # solver keeps beta tiny while the cloud is broad, and the + # -ln p_s correction cancels incremental concentration) -- the + # proposal never localizes, exactly the eff_samp~1 stall seen in + # the ILE SNR sequence. Rank weights are lnL-scale-free and + # compound: fit the top-k samples BY lnL (prior/proposal + # corrected within the elite set), so the threshold ratchets up + # every chunk like the AV sampler's volume shrinking. Hand back + # to the honest beta-solver once ESS(beta=1) clears the floor, + # after which the refit target smoothly becomes L*p (beta->1). + k_elite = int(min(max(self.ess_target, self.ess_floor), len(lnL)//2)) + if k_elite >= 2: + gamma = self.xpy.sort(lnL)[-k_elite] + neg_inf = -self.xpy.inf*self.xpy.ones(lnL.shape) + log_weights = self.xpy.where(lnL >= gamma, log_pq, neg_inf) + adapt_mode = 'elite' + else: + # If even beta=0 leaves too few effective samples (proposal/prior + # pathologies), skip this refit: keep the current proposal rather + # than fit garbage (breadcrumb item 2). + if self.xpy.exp(log_ess) < min(self.ess_floor, self.d + 2): + print(" GMM refit skipped: ESS {:.1f} too low even untempered ".format(float(self.xpy.exp(log_ess)))) + return + if getattr(self, '_verbose_diag', False): + print(" GMM adapt[{}]: mode={} beta={:.3g} ESS_beta={:.1f} ESS_1={:.3g} max_lnL={:.1f}".format( + self.iterations, adapt_mode, float(beta), + float(self.xpy.exp(log_ess)), float(self.xpy.exp(log_ess1)), + float(self.xpy.max(lnL)))) + for dim_group in self.gmm_dict: # iterate over grouped dimensions if self.gmm_adapt: if (dim_group in self.gmm_adapt): @@ -209,8 +327,11 @@ def _train(self): continue new_bounds = self.xpy.empty((len(dim_group), 2)) new_bounds = self.bounds[dim_group] + if len(new_bounds.shape) < 2: + # 1-d group with flat bounds (per-dim default): GMM expects (d,2) + new_bounds = self.xpy.array([new_bounds]) model = self.gmm_dict[dim_group] - temp_samples = self.xpy.empty((self.n, len(dim_group))) + temp_samples = self.xpy.empty((len(sample_array), len(dim_group))) index = 0 for dim in dim_group: temp_samples[:,index] = sample_array[:,dim] @@ -308,6 +429,7 @@ def integrate(self, func, min_iter=10, max_iter=20, var_thresh=0.0, max_err=10, tripwire_epsilon = kwargs["tripwire_epsilon"] if "tripwire_epsilon" in kwargs else 0.001 self.use_lnL = use_lnL self.return_lnI = return_lnI + self._verbose_diag = verbose # per-chunk adaptation diagnostics in _train err_count = 0 cumulative_eval_time = 0 diff --git a/MonteCarloMarginalizeCode/Code/RIFT/integrators/gaussian_mixture_model.py b/MonteCarloMarginalizeCode/Code/RIFT/integrators/gaussian_mixture_model.py index 23d5a241d..05fdf1874 100755 --- a/MonteCarloMarginalizeCode/Code/RIFT/integrators/gaussian_mixture_model.py +++ b/MonteCarloMarginalizeCode/Code/RIFT/integrators/gaussian_mixture_model.py @@ -95,8 +95,46 @@ def _xpy_logsumexp(a, axis=None): _xpy_eigvals = cupy.linalg.eigvalsh _xpy_eig = cupy.linalg.eigh else: - _xpy_eigvals = np.linalg.eigvals - _xpy_eig = np.linalg.eig + # Symmetric routines on CPU as well: the inputs are covariance/correlation + # matrices. eigvalsh/eigh are faster, return real eigenvalues (no spurious + # complex output from round-off asymmetry), and match the GPU path. + _xpy_eigvals = np.linalg.eigvalsh + _xpy_eig = np.linalg.eigh + + +def _near_psd_impl(x, epsilon, xpy): + ''' + Shared, hardened nearest-PSD projection for covariance matrices. + + Never raises on degenerate input: non-finite entries or non-positive + variances are repaired with an epsilon-scaled diagonal fallback before + the (symmetric, eigh-based) projection, and the projection loop is + bounded. Inputs are in normalized [-1,1] coordinates so an O(epsilon) + diagonal is always a meaningful scale. + ''' + n = x.shape[0] + # repair non-finite entries: they cannot reach the eigensolver + if not bool(xpy.all(xpy.isfinite(x))): + diag = xpy.diag(x).copy() + diag = xpy.where(xpy.isfinite(diag) & (diag > 0), diag, epsilon*xpy.ones(n)) + x = xpy.diag(diag) + # floor non-positive variances so the correlation rescaling is defined + diag = xpy.diag(x) + if bool(xpy.any(diag <= 0)): + floor = xpy.maximum(diag, epsilon) + x = x + xpy.diag(floor - diag) + x = 0.5 * (x + x.T) # symmetrize: eigh assumes it, round-off breaks it + for _ in range(10): # bounded: the legacy `while True` could spin forever + var_list = xpy.sqrt(xpy.diag(x)) + y = x / (var_list[:, None] * var_list[None, :]) + if bool(xpy.min(_xpy_eigvals(y)) > epsilon): + return x + eigval, eigvec = _xpy_eig(y) + val_psd = xpy.maximum(eigval, epsilon) + near_corr = eigvec @ xpy.diag(val_psd) @ eigvec.T + near_cov = near_corr * (var_list[:, None] * var_list[None, :]) + x = 0.5 * (near_cov.real + near_cov.real.T) + return x def gpu_logpdf(x, mean, cov, xpy): @@ -170,9 +208,18 @@ def __init__(self, k, max_iters=100, tempering_coeff=1e-8,adapt=None): self.identity_convert_togpu = identity_convert_togpu def _initialize(self, n, sample_array, log_sample_weights=None): - p_weights = self.xpy.exp(log_sample_weights - self.xpy.max(log_sample_weights)).flatten() - p_weights[self.xpy.isnan(p_weights)] = 0 # zero out the nan weights - p_weights /= self.xpy.sum(p_weights) + if log_sample_weights is None: + log_sample_weights = self.xpy.zeros(n) + finite_max = self.xpy.max(self.xpy.where(self.xpy.isfinite(log_sample_weights), log_sample_weights, -self.xpy.inf)) + if not bool(self.xpy.isfinite(finite_max)): + finite_max = 0.0 # no finite weights at all: fall back to uniform + p_weights = self.xpy.exp(log_sample_weights - finite_max).flatten() + p_weights[~self.xpy.isfinite(p_weights)] = 0 # zero out the nan/inf weights + w_sum = self.xpy.sum(p_weights) + if not bool(w_sum > 0): + p_weights = self.xpy.ones(n) + w_sum = 1.0 * n + p_weights /= w_sum self.means = sample_array[self.xpy.random.choice(n, self.k, p=p_weights.astype(sample_array.dtype)), :] self.covariances = [self.xpy.identity(self.d)] * self.k self.weights = self.xpy.ones(self.k) / self.k @@ -210,26 +257,49 @@ def _e_step(self, n, sample_array, log_sample_weights=None): def _m_step(self, n, sample_array): ''' - Maximization step + Maximization step. + + Works in the log domain: self.p_nk holds *log* responsibilities + (including the normalized log sample weights). Normalizing within + each component via logsumexp BEFORE exponentiating keeps the + means/covariances well-defined even when the raw weights span + thousands of nats (high-SNR refits): the dominant responsibilities + are O(1) by construction instead of all underflowing to zero. ''' - p_nk = self.xpy.exp(self.p_nk) - weights = self.xpy.sum(p_nk, axis=0) # weight of a single component + log_p_nk = self.p_nk + # per-component log total responsibility (log of the old `weights`) + log_w = _xpy_logsumexp(log_p_nk, axis=0) for index in range(self.k): if self.adapt[index]: - # (16.1.6) - w = weights[index] # should be 1 for a single component, note - p_k = p_nk[:,index] - mean = self.xpy.sum(self.xpy.multiply(sample_array, p_k[:,self.xpy.newaxis]), axis=0) - mean /= w - self.means[index] = mean - # (16.1.6) + if not bool(self.xpy.isfinite(log_w[index])): + # component received zero/non-finite weight: keep previous params + continue + # responsibilities normalized within this component: sum to 1 + r_k = self.xpy.exp(log_p_nk[:,index] - log_w[index]) + mean = self.xpy.sum(self.xpy.multiply(sample_array, r_k[:,self.xpy.newaxis]), axis=0) diff = sample_array - mean - cov = self.xpy.dot((p_k[:,self.xpy.newaxis] * diff).T, diff) / w - self.covariances[index] = self._near_psd(cov) + cov = self.xpy.dot((r_k[:,self.xpy.newaxis] * diff).T, diff) + # Guard BEFORE _near_psd (breadcrumb item 1): a degenerate weighted + # covariance (all responsibility on ~1 sample, ESS < d+1) or any + # non-finite entry must not reach the eigensolver. Keep the + # previous covariance (identity at init) and only update the mean. + ess_k = 1.0 / self.xpy.sum(r_k**2) + cov_ok = bool(self.xpy.all(self.xpy.isfinite(cov))) \ + and bool(self.xpy.trace(cov) > 0) \ + and bool(ess_k >= self.d + 1) + self.means[index] = mean + if cov_ok: + self.covariances[index] = self._near_psd(cov) # (16.17) - weights /= self.xpy.sum(p_nk[:,self.adapt]) - weights /= self.xpy.sum(weights) - self.weights = weights + # mixture weights via logsumexp over ALL components (the legacy + # double normalization cancels to exactly this softmax) + log_w_safe = self.xpy.where(self.xpy.isfinite(log_w), log_w, -self.xpy.inf*self.xpy.ones(self.k)) + log_norm = _xpy_logsumexp(log_w_safe) + if bool(self.xpy.isfinite(log_norm)): + weights = self.xpy.exp(log_w_safe - log_norm) + w_sum = self.xpy.sum(weights) + if bool(w_sum > 0) and bool(self.xpy.all(self.xpy.isfinite(weights))): + self.weights = weights / w_sum def _tol(self, n): @@ -243,35 +313,8 @@ def _near_psd(self, x): ''' Calculates the nearest postive semi-definite matrix for a correlation/covariance matrix ''' - n = x.shape[0] - var_list = self.xpy.array([self.xpy.sqrt(x[i,i]) for i in range(n)]) - # Use broadcasting for y instead of nested list comprehension - y = x / (var_list[:, None] * var_list[None, :]) - while True: - epsilon = self.epsilon - if self.xpy.min(_xpy_eigvals(y)) > epsilon: - return x - - var_list = self.xpy.array([self.xpy.sqrt(x[i,i]) for i in range(n)]) - y = x / (var_list[:, None] * var_list[None, :]) - - eigval, eigvec = _xpy_eig(y) - val = self.xpy.maximum(eigval, epsilon) - vec = eigvec + return _near_psd_impl(x, self.epsilon, self.xpy) - # Standard PSD projection: - val_psd = self.xpy.maximum(eigval, epsilon) - near_corr = vec @ self.xpy.diag(val_psd) @ vec.T - - # Re-scale back to covariance - near_cov = near_corr * (var_list[:, None] * var_list[None, :]) - - if self.xpy.isreal(near_cov).all(): - break - else: - x = near_cov.real - return near_cov - def fit(self, sample_array, log_sample_weights): ''' Fit the model to data @@ -324,10 +367,15 @@ class gmm: More sophisticated implementation built on top of estimator class ''' - def __init__(self, k, bounds, max_iters=1000,epsilon=None,tempering_coeff=1e-8): + def __init__(self, k, bounds, max_iters=1000,epsilon=None,tempering_coeff=1e-8,memory_factor=3.0): self.k = k self.bounds = bounds self.max_iters = max_iters + # update() merge memory: the old model enters the merge with weight + # min(N, memory_factor*M) instead of the full cumulative N, so an + # early bad fit cannot accumulate unbounded inertia (the proposal can + # always recover within ~memory_factor chunks). + self.memory_factor = memory_factor self.means = [None] * k self.covariances =[None] * k self.weights = [None] * k @@ -403,8 +451,13 @@ def _match_components(self, new_model): def _merge(self, new_model, M): ''' - Merge corresponding components of new model and old model + Merge corresponding components of new model and old model. + + The old model's merge weight is capped at memory_factor*M (bounded + memory): with the legacy cumulative self.N an early bad fit dominated + every later merge and the proposal could never recover. ''' + N_merge = min(self.N, self.memory_factor * M) if self.memory_factor else self.N order = self._match_components(new_model) for i in range(self.k): j = order[i] @@ -414,22 +467,22 @@ def _merge(self, new_model, M): temp_cov = new_model.covariances[j] old_weight = self.weights[i] temp_weight = new_model.weights[j] - denominator = (self.N * old_weight) + (M * temp_weight) + denominator = (N_merge * old_weight) + (M * temp_weight) - mean = (self.N * old_weight * old_mean) + (M * temp_weight * temp_mean) + mean = (N_merge * old_weight * old_mean) + (M * temp_weight * temp_mean) mean /= denominator - cov1 = (self.N * old_weight * old_cov) + (M * temp_weight * temp_cov) + cov1 = (N_merge * old_weight * old_cov) + (M * temp_weight * temp_cov) cov1 /= denominator # outer product for means - cov2 = (self.N * old_weight * self.xpy.outer(old_mean, old_mean)) + (M * temp_weight * self.xpy.outer(temp_mean, temp_mean)) + cov2 = (N_merge * old_weight * self.xpy.outer(old_mean, old_mean)) + (M * temp_weight * self.xpy.outer(temp_mean, temp_mean)) cov2 /= denominator cov = cov1 + cov2 - self.xpy.outer(mean, mean) cov = self._near_psd(cov) - weight = denominator / (self.N + M) + weight = denominator / (N_merge + M) self.means[i] = mean self.covariances[i] = cov @@ -439,32 +492,15 @@ def _near_psd(self, x): ''' Calculates the nearest postive semi-definite matrix for a correlation/covariance matrix ''' - n = x.shape[0] - var_list = self.xpy.array([self.xpy.sqrt(x[i,i]) for i in range(n)]) - y = x / (var_list[:, None] * var_list[None, :]) - while True: - epsilon = self.epsilon - if self.xpy.min(_xpy_eigvals(y)) > epsilon: - return x - - var_list = self.xpy.array([self.xpy.sqrt(x[i,i]) for i in range(n)]) - y = x / (var_list[:, None] * var_list[None, :]) - - eigval, eigvec = _xpy_eig(y) - val_psd = self.xpy.maximum(eigval, epsilon) - near_corr = eigvec @ self.xpy.diag(val_psd) @ eigvec.T - near_cov = near_corr * (var_list[:, None] * var_list[None, :]) - if self.xpy.isreal(near_cov).all(): - break - else: - x = near_cov.real - return near_cov + return _near_psd_impl(x, self.epsilon, self.xpy) def update(self, sample_array, log_sample_weights=None): ''' Updates the model with new data without doing a full retraining. ''' - self.tempering_coeff /= 2 + # halve the covariance regularizer but FLOOR it: an unbounded decay + # (the legacy behavior) eventually leaves sharp refits unregularized + self.tempering_coeff = max(self.tempering_coeff / 2, 1e-12) new_model = estimator(self.k, self.max_iters, self.tempering_coeff) # Filter non-finite @@ -523,10 +559,15 @@ def score(self, sample_array,assume_normalized=True): my_cdf = norm(loc=mean_cpu, scale=sigma_cpu).cdf normalization_constant += w * (my_cdf(bounds_norm_cpu[1]) - my_cdf(bounds_norm_cpu[0])) + # Floors: a sharply-truncated component can drive the mvnun + # normalization to 0 (0/0 -> NaN), and exactly-zero scores later become + # log(0) = -inf in the integrator's weights. 1e-300 keeps the log + # finite without affecting any sample that carries real weight. + normalization_constant = max(float(normalization_constant), 1e-300) scores /= normalization_constant vol = self.xpy.prod(self.bounds[:,1] - self.bounds[:,0]) scores *= (2.0**self.d) / vol - return scores + return self.xpy.maximum(scores, 1e-300) def sample(self, n, use_bounds=True): ''' diff --git a/MonteCarloMarginalizeCode/Code/RIFT/integrators/mcsamplerEnsemble.py b/MonteCarloMarginalizeCode/Code/RIFT/integrators/mcsamplerEnsemble.py index a7075c1bc..ac433f43a 100755 --- a/MonteCarloMarginalizeCode/Code/RIFT/integrators/mcsamplerEnsemble.py +++ b/MonteCarloMarginalizeCode/Code/RIFT/integrators/mcsamplerEnsemble.py @@ -183,6 +183,9 @@ def setup(self,n_comp=None,**kwargs): gmm_epsilon = kwargs['gmm_epsilon'] if "gmm_epsilon" in kwargs else None L_cutoff = kwargs["L_cutoff"] if "L_cutoff" in kwargs else None tempering_exp = kwargs["tempering_exp"] if "tempering_exp" in kwargs else 1.0 + tempering_adapt = kwargs["tempering_adapt"] if "tempering_adapt" in kwargs else False + ess_target = kwargs["ess_target"] if "ess_target" in kwargs else None + ess_floor = kwargs["ess_floor"] if "ess_floor" in kwargs else None lnw_failure_cut = kwargs["lnw_failure_cut"] if "lnw_failure_cut" in kwargs else None nmax = kwargs["nmax"] if "nmax" in kwargs else 1e6 neff = kwargs["neff"] if "neff" in kwargs else 1000 @@ -225,7 +228,8 @@ def setup(self,n_comp=None,**kwargs): bounds[dims]=bounds_here self.integrator = monte_carlo.integrator(dim, bounds, gmm_dict, n_comp, n=self.n, prior=self.calc_pdf, - user_func=integrator_func, proc_count=proc_count,L_cutoff=L_cutoff,gmm_adapt=gmm_adapt,gmm_epsilon=gmm_epsilon,tempering_exp=tempering_exp) + user_func=integrator_func, proc_count=proc_count,L_cutoff=L_cutoff,gmm_adapt=gmm_adapt,gmm_epsilon=gmm_epsilon,tempering_exp=tempering_exp, + tempering_adapt=tempering_adapt, ess_target=ess_target, ess_floor=ess_floor) def update_sampling_prior(self,ln_weights, n_history,tempering_exp=1,log_scale_weights=True,floor_integrated_probability=0,external_rvs=None,**kwargs): rvs_here = self._rvs @@ -333,6 +337,11 @@ def integrate(self, func, *args,**kwargs): gmm_epsilon = kwargs['gmm_epsilon'] if "gmm_epsilon" in kwargs else None L_cutoff = kwargs["L_cutoff"] if "L_cutoff" in kwargs else None tempering_exp = kwargs["tempering_exp"] if "tempering_exp" in kwargs else 1.0 + # --adapt-adapt: ESS-self-tuned refit exponent (previously silently + # dropped by this sampler; only mcsampler/mcsamplerGPU honored it) + tempering_adapt = kwargs["tempering_adapt"] if "tempering_adapt" in kwargs else False + ess_target = kwargs["ess_target"] if "ess_target" in kwargs else None + ess_floor = kwargs["ess_floor"] if "ess_floor" in kwargs else None lnw_failure_cut = kwargs["lnw_failure_cut"] if "lnw_failure_cut" in kwargs else None max_err = kwargs["max_err"] if "max_err" in kwargs else 10 @@ -387,7 +396,8 @@ def integrate(self, func, *args,**kwargs): bounds[dims]=bounds_here integrator = monte_carlo.integrator(dim, bounds, gmm_dict, n_comp, n=n, prior=self.calc_pdf, - user_func=integrator_func, proc_count=proc_count,L_cutoff=L_cutoff,gmm_adapt=gmm_adapt,gmm_epsilon=gmm_epsilon,tempering_exp=tempering_exp) + user_func=integrator_func, proc_count=proc_count,L_cutoff=L_cutoff,gmm_adapt=gmm_adapt,gmm_epsilon=gmm_epsilon,tempering_exp=tempering_exp, + tempering_adapt=tempering_adapt, ess_target=ess_target, ess_floor=ess_floor) if not direct_eval: func = self.evaluate if use_lnL: @@ -427,7 +437,10 @@ def integrate(self, func, *args,**kwargs): self._rvs['integrand'] = self.identity_convert(value_array) if bFairdraw and not(n_extr is None): - n_extr = int(self.xpy.min([n_extr,1.5*eff_samp,1.5*neff])) + # scalars: use Python min on floats. self.xpy.min([list]) fails on cupy + # (cupy.min has no list overload -> "'list' object has no attribute 'min'"), + # which crashed the GMM sampler's fairdraw export on GPU. + n_extr = int(min(float(n_extr), 1.5*float(eff_samp), 1.5*float(neff))) print(" Fairdraw size : ", n_extr) if return_lnI: ln_wt = integrator.cumulative_values diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/DESIGN.md b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/DESIGN.md new file mode 100644 index 000000000..bee0915ec --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/DESIGN.md @@ -0,0 +1,174 @@ +# jax_gp — design & rationale + +Status: living design doc for the AD-compatible likelihood-interpolation effort. +See `README.md` for usage; this file is the *why* and the long-term plan. + +## Problem + +CIP fits `lnL` over intrinsic parameters, then samples it. Two long-standing +limitations motivate this work: + +1. **Scaling of the exact GP.** The legacy sklearn `GaussianProcessRegressor` + path is O(N³); at our scale (N ~ 2·10⁴–5·10⁴, d ~ 8–12) it is intractable, so + the pipeline leans on `--cap-points`, `--lnL-offset`, pooling, etc. +2. **Non-differentiable export.** We export lnL grid evaluations (or a NumPy + sklearn pickle / a black-box NN) and hope. Downstream users who need a + *differentiable* `lnL(θ)` cannot get one. + +## The honest competitive picture: RF is the bar + +Random forests (RIFT's default `fit_rf`, ExtraTrees) are **excellent** and very +hard to beat on raw timing+accuracy: + +- robust, no hyperparameter tuning, +- cover the full lnL dynamic range, +- extremely fast to fit. + +Measured, GW170817 (good coords, lnLmax−20 band, ~7.6k train pts): + +| method | peak-wtd rmse | fit time | differentiable? | +|---|---|---|---| +| RF (ExtraTrees) | 1.84 | **5 s** | **no** | +| SVGP (this work) | 1.64 | 432 s | **yes** | + +So the GP is *not* going to win a straight timing+accuracy race against RF, and +we should stop pretending it will. **The GP's reason to exist is different:** + +1. **Fewer function evaluations.** RF needs large training volumes to be accurate; + a GP reaches comparable accuracy from far fewer points. Our training points are + *expensive to make* (each is an ILE evaluation), so "accurate with less data" + directly reduces the dominant cost — even if the fit itself is slower. Reducing + function evaluations is an explicit program goal, not a side benefit. +2. **Automatic differentiation.** RF is piecewise-constant — no usable gradient. + The GP gives a smooth, exact `∇lnL`. This is the capability several downstream + applications *cannot do without* (below). + +The GP's extra fit cost is acceptable **if** it is part of a cohesive AD tooling +set we need for other reasons — which it is. + +## Why AD is critical (the applications driving this) + +These are the concrete reasons a differentiable `lnL(θ)` is a hard requirement, +not a nicety: + +1. **Population inference (AD / numpyro).** Hierarchical/population analyses built + on AD frameworks (numpyro, etc.) consume *individual-event* likelihoods and need + their **derivatives**. A differentiable per-event lnL surrogate (this package's + export) is the missing piece that lets per-event results flow into gradient-based + population inference. +2. **Differentiable sampling (replace brute-force MC).** Our integrals are + currently done with brute-force Monte Carlo. A **derivative-aware sampler** + (HMC/NUTS, normalizing-flow / SVI, Langevin, …) is *enormously* more efficient, + especially at **high SNR** where the posterior is sharp and MC is wasteful. To + get there: + - **CIP:** needs (a) **AD fits** — delivered here by the GP — and (b) a + **derivative-aware sampler** wired into the CIP integrator (pending). + - **ILE:** needs (a) porting one of the likelihoods from **cupy → JAX** so the + extrinsic integral itself is differentiable, and (b) a derivative-aware + sampler. Larger effort; longer-term. + +## Downselect decision (current) + +**RFF is the default jax method**; SVGP and exact are kept as **backstop / +validation** code. Rationale: no scaling reason favors SVGP (both are O(N M²) time, +O(N M) memory, linear in N), and RFF is empirically faster *and* more accurate on +our benchmarks with the smaller constant factor (fixed feature basis; no k-means or +inducing-point optimization). SVGP is retained because (a) it is the principled +inducing-point method and a useful cross-check, and (b) its adaptive inducing points +and calibrated predictive variance are the natural seed for a future +uncertainty-driven **active-learning / sample-placement** loop (reduce function +evaluations further). The GP will **not** replace RF in the standard CIP stack; +its value is the AD use cases below. + +## Long-term roadmap + +Ordered roughly by dependency, not committed dates: + +1. **(done)** AD-compatible GP fits + a self-contained differentiable export + (`export.py`). RFF / SVGP / exact behind one interface; heteroscedastic noise; + ARD; ILE `.net` loader; good fit coordinates. +2. **Tune the scalable GP.** Make SVGP competitive with RFF (more/better inducing + points, more steps, possibly input warping for sharp peaks); decide the + per-regime default from `benchmark/scaling_study.py`. +3. **Population-inference hookup.** Provide a clean numpyro-friendly loader so an + exported per-event lnL drops into an AD population model. (Adjacent: the + `gwkokab` work.) +4. **Derivative-aware sampler (clean JAX path).** Done in prototype: + `applications/jax_cip.py`. Default sampler is **flow-as-sampling-model**, not + NUTS-only: a flowMC run trains a normalizing flow q(theta) while gradient MALA + moves explore, then we draw i.i.d. from the flow and importance-weight by + exp(lnL + log_prior - log_q). This *decouples ESS from MCMC autocorrelation* + (efficiency becomes a flow-training knob — "beat on the flow") and yields the + **evidence Z** for free. On GW170817 this gave physical tidal values (lambda1 ~ + 240-470) and logZ ~ 490 where NUTS-only stalled (ESS ~ 3, spurious lambdas); + ESS rose ~1%→6% as flow training increased. Next: physical-support priors + (lambda>=0, |delta_mc|<1) via a constrained base so IS can't chase GP + extrapolation; iterate the flow toward ESS~1; prior Jacobians for science-grade + evidence. Built as a separate path, not a legacy CIP retrofit (Architecture above). +5. **ILE cupy→JAX likelihood port + derivative-aware extrinsic sampler.** The big + one: a differentiable ILE likelihood end-to-end. Enables differentiable + extrinsic marginalization, not just a differentiable surrogate of its output. + +## Architecture: clean JAX path, not a CIP retrofit + +Decision: rather than shoehorn JAX into the ~2000-line legacy CIP (a Rube-Goldberg +risk), the JAX use cases live in a **separate, clean pure-JAX path** +(`applications/jax_cip.py`): load ILE → good fit coords → differentiable RFF fit → +numpyro NUTS → physical-parameter AD hook. The legacy CIP stays the production +path, untouched. The two will inevitably drift in feature parity, but the JAX use +cases are *qualitatively different* (gradient sampling, AD population inference) +and far simpler to manage on their own than as a bolt-on. Coordinate transforms +are reimplemented in pure JAX (`coordinates.py`) but validated against the legacy +NumPy source of truth; the legacy transforms are not modified. + +**Sampling lesson (demonstrated on GW170817):** sample in the *decorrelated fit* +coordinates (mu1, mu2, …), with the prior bounded to the training support. There +NUTS mixes well and returns a sensible posterior (LambdaTilde ~ 226). Sampling in +raw physical (m1, m2) reintroduces the curved chirp-mass degeneracy the mu coords +exist to remove, and a diagonal-metric sampler chokes (ESS ~ 3). The +physical-parameter *gradient* is still produced (the population-inference hook), +evaluated at in-support points — it just isn't used as the sampling geometry. + +**Performance caveat:** all timings in this repo were measured on an old CPU box +with a weak GPU. Production hardware is far faster; treat the numbers as *relative* +(method-vs-method, scaling trends), not absolute targets. + +## Design choices (and why) + +- **JAX**, not PyTorch: composability with numpyro / optax / the population-inference + stack, and a clean pure-function export users can `jax.grad`. +- **Hand-rolled SGPR** (Titsias collapsed bound), not gpjax: no version coupling to + a fast-moving lib against jax 0.10; the predictive mean stays a transparent, + exportable closed form. +- **Whitening + ARD + good coordinates** carry most of the fit-quality water. + Coordinate choice (e.g. `mu1,mu2,delta_mc,LambdaTilde,DeltaLambdaTilde` for BNS) + matters more than the interpolator; see README "Coordinates matter". +- **Heteroscedastic noise**: ILE lnL has a per-point MC error (the `sigma/L` + column); using it cuts held-out error by 10–70× on noisy data. Always on when + errors are available. +- **Export is differentiable in fit coordinates.** Pushing derivatives back to raw + physical parameters needs a JAX reimplementation of the coordinate transforms — + deferred (would also be needed for a fully-AD CIP sampler). + +## Current benchmark findings + +- **Scaling sweep (16 cells: {svgp,rff} × d∈{8,12} × N∈{2k,20k} × {correlated_gaussian, + sharp_peak}) — RFF beats SVGP on every cell and is faster.** E.g. correlated_gaussian + d=12 N=20k: RFF rmse 0.21 vs SVGP 0.58; sharp_peak d=12 N=20k: RFF 0.08 vs SVGP 0.25. + Both improve with N and have grad-cosine ≈ 0.99–1.0. This is why RFF is the default; + making SVGP competitive is roadmap step 2. +- On GW170817 the GP modestly beats RF on accuracy but not on speed — we buy AD (and the + few-evaluations regime), not raw throughput. + +## AD applications (built — see `applications/`) + +The use cases that justify the GP, now prototyped: + +- **`export_artifact.py`** — packages a real ILE run into a self-contained + differentiable lnL (e.g. a 30 KB GW170817 RFF export in BNS coords). This is the + "package it sanely" product downstream users consume. +- **`diff_sampler.py`** — gradient-based sampling (numpyro NUTS, flowMC) of the + fitted lnL. On a sharp synthetic posterior NUTS recovers the analytic answer at + **~300× higher ESS-per-lnL-evaluation** than a matched-budget gradient-free + random walk — a direct demonstration of the high-SNR sampling payoff (roadmap + step 4, here on the surrogate; the real CIP-integrator swap is the next step). diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/HANDOFF.md b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/HANDOFF.md new file mode 100644 index 000000000..46c793b54 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/HANDOFF.md @@ -0,0 +1,116 @@ +# jax_gp — handoff / resume here + +Snapshot for whoever picks this up next (branch `rift_O4d_junior_interp_jax` → `junior`). +Read alongside `DESIGN.md` (rationale/roadmap) and `README.md` (usage). Dev env: the +`gwkokab` conda env; run with `PYTHONPATH=<.../MonteCarloMarginalizeCode/Code>`. + +## The goal (reframed) +Produce a **differentiable** lnL surrogate for CIP that is **unbiased** vs the +production CIP+RF posterior, to **PE standards: JS divergence ~ few×10⁻³ bits** on +every 1D marginal (mc, LambdaTilde/tides, q/delta_mc, spins). GP training is **slow → +offline / handoff use**; we are NOT trying to beat RF on speed, only to validate the +exported surrogate is good enough and not biased. + +## What is SOLID (don't relitigate) +- **Architecture is right: `quadgp` = quadratic/Fisher core + GP residual** (`quad_gp.py`). + A pure GP cannot match a razor-sharp quasi-quadratic peak to few-% width; the + quadratic captures the exact Fisher curvature (sharp eigen-dirs only, via + `keep_curv_frac`), the GP fits the smooth residual. **This nails mc: width 6.5e-5 vs + truth 6.9e-5 (JS ~0.035) in every config.** That was the wall; it's cleared. +- **Constrain GP lengthscales** (`svgp.py`/`exact.py`): free hyperopt over-smooths + (lengthscale runs long); we clip the ARD lengthscale to ~the near-peak width. Keep this. +- **Sampler for a SHARP surrogate = importance sampling, NOT the flow.** A flow can't + learn a 5e-5 peak in a 3e-3 box (ESS→5). `sample_gaussian_is` with a peak-matched + Gaussian proposal works (ESS ~hundreds). +- **Morisaki (mu) frame proposal** (`_muframe_proposal`): build the proposal covariance + in fit coords (well-conditioned; the physical low-level cov is near-singular in mc), + pull back via the JAX Jacobian `P_low = J^T C_fit^-1 J + diag(1/prior_var)`. Sample + stays physical — no inverse transform. +- **Benchmark + JS harness done.** 10× CIP RF+AV in `applications/benchmark_condor/` + (matches the paper's `posterior_samples-6.dat`, LambdaTilde 343±183 ✓). 50k pooled + samples cached at `/home/oshaughn/jaxcip_benchmark/out/cip_rf_*.xml.gz`. + `applications/compare.py` computes JS (bits) with a bootstrap stderr. + **CAVEAT (RO): the RF benchmark is a REFERENCE, not assumed-converged ground truth.** + Those runs were short — harvested to accumulate likelihoods, not tuned for perfect + sampling convergence. So part of the residual JS may be the *benchmark*, not our + surrogate. SAFEST validation (TODO): re-benchmark BOTH the surrogate and a fresh RF + run from the SAME initial lnL grid (use the 'large' grid from the tabular runs or + Atul's runs), so the comparison isolates surrogate-vs-RF with no grid/convergence + confound. Current numbers are good enough for downstream teams to code against. + +## Current JS (quadgp + svgp-residual 10.8k, vs benchmark) — sampler comparison +| sampler | mc | delta_mc | s1z | s2z | lambda1 | lambda2 | LambdaTilde | +|---|---|---|---|---|---|---|---| +| mu-frame gaussian-IS | 0.035 | **0.356** | 0.053 | 0.067 | 0.152 | — | 0.116 | +| **mu-frame NUTS (`nuts-mu`)** | **0.023** | **0.056** | **0.008** | **0.008** | **0.016** | **0.015** | **0.028** | + +**NUTS-in-mu DONE and it confirmed the diagnosis: it was the sampler.** Every marginal +improved; delta_mc (the worst IS regression) 6.4×; spins ~8e-3 ≈ at the bar. From +catastrophic IS (0.04–0.36) → uniformly small JS. Still NOT uniformly few×10⁻³ (mc 0.023, +LambdaTilde 0.028, delta_mc 0.056) — but the residual now behaves **surrogate/data-limited, +not sampler-limited** (spins, where the surrogate is best, are at the bar; gap is in the mc +width ~16% too broad + the broadest dirs). Single-seed JS is noisy on razor-sharp mc. + +## What `nuts-mu` is (DONE; `sample_nuts_muframe` in applications/jax_cip.py) +NUTS in **low-level** coords (output + box are natural; the 5→6 fit→low map isn't +invertible so we can't sample in fit coords), **preconditioned** with the mu-frame +covariance: `_muframe_proposal` builds a well-conditioned cov in the fit frame and pulls +it back to low-level (`P_low = Jᵀ C_fit⁻¹ J + diag(1/prior_var)`). numpyro reparam's the +Uniform box as `theta = lo+(hi-lo)·sigmoid(u)`, so we seed the dense mass matrix with that +cov **mapped into u-space** by the local sigmoid Jacobian (`imm = S⁻¹ gcov S⁻¹`, +`S = (hi-lo)·s·(1-s)` at the peak); init at the peak; adapt_mass_matrix=True re-adapts. +Unit-tested on a 4-orders-of-mag-anisotropic correlated Gaussian (ESS ~4–5k/6k, 0 div, +σ recovered 0.5%). Run via `--sampler nuts-mu --num-chains N`. Demo: +`demo/rift/export_likelihoods/`. + +### Next step (highest value): close the last factor (now surrogate/data, not sampler) +1. **More/uncapped data + `--quadgp-residual exact` cross-check** at the largest tractable + N — does the inducing-point approx cost accuracy at scale? Push mc width + LambdaTilde down. +2. **Reduce the 14 divergences** (raise target_accept; check if they cluster at box edges in + the weakly-constrained dirs). +3. **Multi-seed JS + bootstrap** for publication-grade error bars (single-seed mc JS is noisy). +4. **Tighten the quadratic-core mc localization** (residual ~16% mc width is the dominant + remaining bias on the sharpest direction). + +## How to run +```bash +cd .../MonteCarloMarginalizeCode/Code +P=/home/oshaughn/.conda/envs/gwkokab/bin/python +# surrogate + sampler -> posterior XML +PYTHONPATH="$PWD" $P -m RIFT.interpolators.jax_gp.applications.jax_cip \ + --fname /home/oshaughn/all.net \ + --parameter delta_mc --parameter-implied mu1 --parameter-implied mu2 \ + --parameter-implied LambdaTilde --parameter-implied DeltaLambdaTilde \ + --parameter-nofit mc --parameter-nofit s1z --parameter-nofit s2z \ + --parameter-nofit lambda1 --parameter-nofit lambda2 \ + --mc-range '[1.196,1.199]' --chi-max 0.05 \ + --cap-points 12000 --jax-fit-method quadgp --quadgp-residual svgp \ + --n-features 800 --n-opt-steps 250 --sampler gaussian \ + --fname-output-samples /tmp/jaxcip_out +# JS vs the cached benchmark (per param) +PYTHONPATH="$PWD" $P -m RIFT.interpolators.jax_gp.applications.compare \ + --a /tmp/jaxcip_out.xml.gz \ + --b '/home/oshaughn/jaxcip_benchmark/out/cip_rf_*.xml.gz' --param mc +``` +Fast surrogate-only diagnostic (fit + importance-sample widths, ~3 min): see the +inline scripts in the session, or fit `get_interpolator("quadgp")(...)` and IS in +low-level coords (order: mc, delta_mc, s1z, s2z, lambda1, lambda2 — don't swap mc/delta_mc). + +## Key files +- `quad_gp.py` — quadratic core + GP residual (mc-exact). Export DONE: overrides + `export_state`/`from_state` to embed the nested residual model (its arrays under a + `_resid_` prefix, its meta under `meta["resid_meta"]`); round-trips cross-process and + stays `jax.grad`-able after load (see `test_export_roundtrip_quadgp_*`). +- `svgp.py` / `exact.py` — constrained-lengthscale GPs (peak-matched bounds). +- `coordinates.py` — pure-JAX Morisaki/tidal transforms (validated vs lalsimutils). +- `applications/jax_cip.py` — pipeline: tree-ring downselect, fit, samplers + (`flow|nuts|gaussian|mixture`), `_muframe_proposal`, legacy-CIP-compatible CLI + output. +- `applications/compare.py` — JS metric. `applications/benchmark_condor/` — RF+AV fleet. + +## Don'ts (learned the hard way) +- Don't let GP hyperparameters fit freely (over-smooths). Constrain the lengthscale. +- Don't use RFF for IS targets — it rings/overshoots → IS ESS collapse. +- Don't expect a flow to sample a razor-sharp surrogate (use peak-matched IS / NUTS-in-mu). +- Don't trust ESS as the success metric — use JS vs the benchmark. ESS-good can be 12× biased. +- Don't re-derive prior bounds from the data — trust the CLI ranges (the grid extends past + the prior on purpose). The naive Gauss+flow mixture is a dead end without box-normalization. diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/README.md b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/README.md new file mode 100644 index 000000000..034cd3583 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/README.md @@ -0,0 +1,169 @@ +# jax_gp — scalable, AD-compatible likelihood interpolation for RIFT/CIP + +JAX-based likelihood interpolators with a shared interface, built for two goals +the legacy CIP fit path does not meet: + +1. **Scale.** Fit lnL over N ~ 2·10⁴–5·10⁴ points in d ~ 8–12 without the O(N³) + blow-up of the exact sklearn GP (and without the `--cap-points` workarounds + that throw information away). +2. **Differentiable export.** Produce a self-contained `lnL(θ)` that downstream + users can `jax.grad` through — replacing "dump the lnL grid and hope." + +This is an **optional** subpackage. It is never imported by the production CIP +path unless a `gp-jax-*` fit method is selected, so the JAX stack is not required +for normal RIFT operation. Install the extra with: + +``` +pip install RIFT[jax-interp] # jax, optax, equinox, tinygp +``` + +## Methods + +| method (`--fit-method`) | class | approach | cost | notes | +|---|---|---|---|---| +| `gp-jax-rff` | `RFFInterpolator` | random Fourier features GP (Bayesian linear regression in feature space) | O(N M²) | cheapest export; weaker on sharp/non-stationary peaks | +| `gp-jax-svgp` | `SVGPInterpolator` | Titsias collapsed sparse GP (SGPR), M inducing points (k-means init) | O(N M²) | scalable production-regime default; hand-rolled pure JAX | +| `gp-jax-exact` | `ExactGPInterpolator` | exact GP (tinygp) | O(N³) | accuracy reference baseline only — not for production N | +| `gp-jax-quadgp`| `QuadraticPlusGPInterpolator` | Fisher-curvature quadratic core + GP residual (`--quadgp-residual {svgp,exact,rff}`) | core of the chosen residual | PE-grade on razor-sharp peaks (mc-exact); the surrogate to export for downstream use | + +All three use **ARD** (per-dimension lengthscales), which matters for the strongly +anisotropic / curved degeneracies in real lnL surfaces (mc–eta, lambda1–lambda2): +on the curved `banana_ridge` benchmark it lifts gradient-cosine-vs-truth from +~0.78 to ~0.95–1.0. + +## Shared interface (`interface.BaseInterpolator`) + +```python +from RIFT.interpolators.jax_gp import get_interpolator +model = get_interpolator("svgp")().fit(X, y, y_errors=yerr) + +fn = model.predict_callable() # callable(np.ndarray[n,d]) -> np.ndarray[n] (CIP contract) +v, g = model.lnL_and_grad(theta) # differentiable lnL + gradient at one point +gfn = model.grad_fn() # jitted pure-JAX theta -> (lnL, grad) +``` + +Every method fits on per-dimension *whitened* coordinates and centered targets; +because whitening is affine, JAX threads the chain rule through it, so gradients +come back in physical (fit-coordinate) units automatically. 64-bit JAX is enabled +on import — lnL gradients need it. + +## Differentiable export (`export.py`) + +```python +from RIFT.interpolators.jax_gp import export +export.save(model, "myfit", coord_names=["mc","eta","chi_eff", ...]) +# -> myfit.npz + myfit.meta.json + +loaded = export.load("myfit") # pure-JAX, differentiable +import jax; grad = jax.grad(loaded.lnL_physical)(theta) +``` + +The exported lnL is differentiable in the *fit* coordinates the GP was trained on +(recorded in `meta.json` as `coord_names`). Pushing the derivative back to raw +physical parameters would require a JAX reimplementation of CIP's coordinate +transforms — out of scope here, noted as future work. + +All four methods export, including `quadgp`: it nests a *full* residual model, so it +embeds that sub-model in the same bundle (residual arrays namespaced `_resid_*` in the +`.npz`, residual meta under `meta["resid_meta"]`) — a single `export.save`/`export.load` +round-trips the whole quadratic-core + residual surrogate and it stays `jax.grad`-able. + +## Coordinates matter (a lot) + +The single biggest lever on fit quality is **which coordinates you fit in** — not +the interpolator. A stationary GP fits a far simpler surface in RIFT's decorrelated +coordinates than in raw `(m1, m2, s1z, s2z, lambda1, lambda2)`. For low-mass / BNS +(e.g. GW170817), fit in `mu1, mu2, delta_mc, LambdaTilde, DeltaLambdaTilde`, where +`mu1, mu2` are Morisaki's orthogonalized PN-phase combinations (`RIFT/misc/tools.py`) +that decorrelate chirp-mass/mass-ratio/spin at low mass. + +In CIP this is done with the parameter flags, e.g.: + +``` +--parameter-implied mu1 --parameter-implied mu2 --parameter-nofit mc \ +--parameter delta_mc --parameter-nofit s1z --parameter-nofit s2z \ +--parameter-implied LambdaTilde --parameter-implied DeltaLambdaTilde \ +--parameter-nofit lambda1 --parameter-nofit lambda2 +``` + +`--parameter` / `--parameter-implied` form the fit coordinates the GP sees; +`--parameter-nofit` are sampled but only used to derive the implied ones. The +conversion is `lalsimutils.convert_waveform_coordinates`; CIP applies it before the +fit. For offline experiments, `benchmark/datasets.to_fit_coordinates` / +`BNS_FIT_COORDS` wrap the same transform. On GW170817, naive→good coordinates cut +SVGP peak-weighted rmse from ~3.0 to ~2.2 nats (and the dimension from 6 to 5). + +## Use from CIP + +``` +util_ConstructIntrinsicPosterior_GenericCoordinates.py \ + --fit-method gp-jax-svgp \ + --fit-save-jax myfit \ + ... (usual CIP args) ... +``` + +`--fit-save-jax ` writes the differentiable export alongside the run. +`--fit-load-gp ` reloads such an export instead of refitting. + +## Benchmarking + +`benchmark/harness.py` sweeps `{method} × {N} × {truth}` against synthetic +ground-truth lnL functions with analytic gradients (in `truth_functions.py`), +scoring value RMSE, peak-weighted RMSE, gradient accuracy, and fit/predict time +(`metrics.py`). The exact GP also serves as the yardstick when no analytic truth +is available. + +``` +python -m RIFT.interpolators.jax_gp.benchmark.harness --d 8 --N 2000 8000 \ + --methods rff svgp exact +``` + +`benchmark/scaling_study.py` sweeps the scalable methods across dimension, N, and +surface shape, writing one JSON line per cell as it completes (crash-resilient): + +``` +python -m RIFT.interpolators.jax_gp.benchmark.scaling_study --out study.jsonl \ + --dims 8 12 --N 2000 20000 --methods svgp rff +``` + +`benchmark/datasets.py` (`load_ile_net`) loads real RIFT ILE `.net`/`.composite` +output so methods can be tested on production lnL surfaces, not only synthetics. + +## Applications (`applications/`) + +The use cases that justify a differentiable surrogate over the faster non-AD RF: + +- `export_artifact.py` — CLI packaging a RIFT ILE `.net` file into a self-contained + differentiable lnL export (good fit coordinates, MC errors, peak cut): + ``` + python -m RIFT.interpolators.jax_gp.applications.export_artifact \ + --net all.net --out gw170817_rff --coords bns + ``` +- `diff_sampler.py` — gradient-based sampling (numpyro NUTS / flowMC) of the fitted + lnL vs a gradient-free baseline, demonstrating the high-SNR efficiency win: + ``` + python -m RIFT.interpolators.jax_gp.applications.diff_sampler --demo synthetic + ``` +- `jax_cip.py` — a clean, standalone pure-JAX posterior path (load ILE → fit → + NUTS) with a **legacy-CIP-compatible CLI** for pipeline-level hot-swap: it + accepts the legacy argument surface inclusively, honours `--parameter*` / + `--fname` / cuts / `--n-output-samples`, and writes the same + `--fname-output-samples` XML + `_lnL.dat`. (I/O contract done; NUTS sample + quality on sharp real surrogates is the science-grade follow-up — see DESIGN.md.) + ``` + python -m RIFT.interpolators.jax_gp.applications.jax_cip --fname all.net \ + --parameter delta_mc --parameter-implied mu1 mu2 LambdaTilde DeltaLambdaTilde \ + --parameter-nofit mc s1z s2z lambda1 lambda2 --fname-output-samples out + ``` + +See `DESIGN.md` for the rationale and roadmap, and `applications/{ARTIFACT,SAMPLER_NOTES}.md`. + +## Tests + +``` +python -m RIFT.interpolators.jax_gp.test_interpolators +``` + +Checks recovery of a known target, AD-vs-finite-difference gradient agreement, +and export round-trip (including `jax.grad` on the reloaded model) for the three +base GP methods plus `quadgp` with both `svgp` and `exact` residuals. diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/__init__.py b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/__init__.py new file mode 100644 index 000000000..b8287a9af --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/__init__.py @@ -0,0 +1,56 @@ +""" +jax_gp : robust, scalable, AD-compatible likelihood interpolators for RIFT/CIP. + +This is an *optional* subpackage. It is never imported by the production CIP +path unless the user explicitly selects a ``gp-jax-*`` fit method, so the JAX +dependency stack (jax, optax, equinox, tinygp, ...) is not required for normal +operation. + +All interpolators share the contract defined in ``interface.BaseInterpolator``: + + model = SomeInterpolator(...).fit(X, y, y_errors=...) + fn = model.predict_callable() # callable(np.ndarray[n,d]) -> np.ndarray[n] + v, g = model.lnL_and_grad(theta) # differentiable lnL + gradient at one point + +``predict_callable`` is the drop-in for the existing CIP fit dispatch (every +``fit_*`` there returns exactly such a callable). ``lnL_and_grad`` -- and the +pure-JAX closure behind it -- is what makes the exported surrogate differentiable +for downstream users. + +We enable 64-bit JAX on import: lnL spans a large dynamic range and single +precision is not adequate for faithful gradients. This is process-global, but +only takes effect once this opt-in subpackage is imported. +""" +from __future__ import annotations + +import jax as _jax + +if not _jax.config.read("jax_enable_x64"): + _jax.config.update("jax_enable_x64", True) + +from .interface import BaseInterpolator # noqa: E402 + +__all__ = ["BaseInterpolator"] + +# Method classes are imported lazily by name to avoid importing every backend +# (and its heavier deps, e.g. tinygp) when only one is needed. +def get_interpolator(name): + """Return the interpolator class registered under ``name``. + + Names mirror the CIP ``--fit-method`` values (without the ``gp-jax-`` prefix): + ``rff``, ``exact``, ``svgp``. + """ + name = name.lower().replace("gp-jax-", "").replace("gp_jax_", "") + if name in ("", "rff"): # RFF is the default jax method + from .rff import RFFInterpolator + return RFFInterpolator + if name == "exact": + from .exact import ExactGPInterpolator + return ExactGPInterpolator + if name == "svgp": + from .svgp import SVGPInterpolator + return SVGPInterpolator + if name in ("quadgp", "quad", "quad-gp"): + from .quad_gp import QuadraticPlusGPInterpolator + return QuadraticPlusGPInterpolator + raise ValueError("Unknown jax_gp interpolator: {!r}".format(name)) diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/ARTIFACT.md b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/ARTIFACT.md new file mode 100644 index 000000000..37e8b78ab --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/ARTIFACT.md @@ -0,0 +1,51 @@ +# Differentiable lnL artifact + +`export_artifact.py` turns a RIFT ILE `.net` file (the per-point Monte-Carlo +likelihood evaluations CIP normally consumes) into a small, self-contained, +**differentiable** surrogate for the marginalized log-likelihood `lnL(theta)`. + +Build one with: + +```bash +python export_artifact.py --net /path/to/all.net --out /tmp/gw170817_rff --coords bns +``` + +## What the artifact is + +A random-Fourier-feature (RFF) regression of `lnL` over the ILE samples, fit in +*decorrelated fit coordinates* using the per-point ILE Monte-Carlo errors +(`sigma_lnL`) as observation noise. Only the informative high-likelihood region is +kept (`lnL > max - lnL_offset`), de-duplicated and sigma-cut exactly as CIP does. +The result is a pure-JAX `lnL(theta)` that `jax.grad` / `jax.value_and_grad` +differentiate out of the box — no RIFT or lalsimutils import is needed to *load* it. + +## File format + +The export is two files sharing a base path: + +- `.npz` — whitening vectors (`x_mean`, `x_std`) plus the RFF parameters + (frequencies, weights), as NumPy arrays. +- `.meta.json` — schema/method/dimension, target centering/scaling + (`y_mean`, `y_std`), and `coord_names`: the names of the axes of `theta`. + +## Coordinate caveat (important) + +The artifact is differentiable in its **fit coordinates** — the list recorded as +`coord_names` in the meta — *not* in the raw physical parameters. For `--coords bns` +these are `('mu1','mu2','delta_mc','LambdaTilde','DeltaLambdaTilde')`; for +`--coords raw` they are the six raw params `(m1,m2,s1z,s2z,lambda1,lambda2)`. +A gradient in fit coordinates is what CIP works in; pushing it back to raw physical +parameters would require a JAX reimplementation of CIP's coordinate transforms, +which is deliberately out of scope. Always read `coord_names` before differentiating. + +## How a downstream user loads and differentiates it + +```python +import jax, jax.numpy as jnp +from RIFT.interpolators.jax_gp import export +model = export.load("/tmp/gw170817_rff") # reconstructs pure-JAX lnL +print(model.coord_names) # axis order of theta +theta = jnp.array([0., 0., 0., 300., 0.]) # a point in fit coordinates +lnL, grad = jax.value_and_grad(model.lnL_physical)(theta) +print(float(lnL), grad) # scalar lnL + finite gradient +``` diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/EXPORT_AT_SCALE.md b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/EXPORT_AT_SCALE.md new file mode 100644 index 000000000..6cb78b10a --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/EXPORT_AT_SCALE.md @@ -0,0 +1,201 @@ +# export_at_scale — ship + validate differentiable lnL artifacts from real run dirs + +`export_at_scale.py` points at a **real RIFT run directory**, exports a continuous, +`jax.grad`-able surrogate for that run's `all.net` intrinsic likelihood, and +**validates the exported artifact** by drawing a posterior from it and comparing the +marginals to the run's own CIP posterior. It scales the single-event +[`export_artifact`](ARTIFACT.md) / [`jax_cip`](../README.md) primitives to a whole +directory of events, locally or over HTCondor. + +At this stage it interpolates **only `all.net`** — the existing intrinsic ILE +deliverable. (Distance-grid export is a separate track.) + +Nothing is written back into the run directory; every output lands under +`--workroot/__/`. + +## What it does, per run + +1. **Discover.** Detect the `all.net` column layout (aligned / precessing / + tidal widths all differ — the trailing `lnL sigma_lnL ntot [neff]` columns are + tail-anchored), parse the active `args_cip_list.txt` for the fit parameters and + the prior box (`--mc-range`, `--eta-range`, `--chi-max`), and find the run's + latest `posterior_samples-.dat` (falling back to + `extrinsic_posterior_samples.dat`). +2. **Export the deliverable.** Fit a surrogate in dimension-agnostic physical fit + coordinates — `[mc, delta_mc]` plus whichever spin/tidal columns actually vary + (constant columns are dropped and recorded) — so the same path covers aligned, + BNS, and **precessing** runs with no hand-written coordinate transform. Save the + `.npz` + `.meta.json` bundle, **reload it**, and assert the reloaded `predict()` + matches and `jax.grad` is finite. +3. **Validate (apples-to-apples).** Draw a posterior *from the reloaded artifact* + whose target is the run's **actual** `lnL + ln prior` — using RIFT's own priors, + sampled in RIFT's own coordinates (spins in `(χ, cosθ, φ)` where the isotropic + prior is flat and there's no Cartesian 1/χ² singularity; the non-uniform mass + prior `mc_prior ∝ mc`, `eta_prior ∝ η^(−6/5)`; the `alignedspin-zprior` for + aligned runs) — by Gaussian importance sampling in low dimension or gradient-based + **NUTS** (using the artifact's `jax.grad` lnL) in high dimension. Then report the + Jensen–Shannon divergence of the `mc` / `q` / `chi_eff` marginals against the CIP + posterior, with an **ESS-based quality flag** so a sampling-limited result is never + mistaken for surrogate error. Writes `posterior_interp.dat`, `report.json`, + `summary.md`, and `marginals.png`. + +## Environment + +Runs in the `rift_ad_export` conda env (a clone of `rift_jax` so the shared env is +never modified): jax 0.9.2, numpyro, flowMC 0.6.0, tinygp, RIFT. `gaussian`/`nuts` +need only jax+numpyro; `--sampler flow` needs flowMC (this env has a 0.6.0-correct +flow sampler built into the tool — the legacy `jax_cip.sample_flow_is` breaks under +flowMC 0.6.0's keyword-only `Sampler` API). Always export `PYTHONPATH` to the RIFT +source tree. + +## Usage + +```bash +PY=~/.conda/envs/rift_ad_export/bin/python +export PYTHONPATH=/path/to/RIFT/MonteCarloMarginalizeCode/Code +M=RIFT.interpolators.jax_gp.applications.export_at_scale + +# inspect what discovery found (no work done) +$PY -m $M discover --run-dir /path/to/rundir + +# one run, immediately +$PY -m $M one --run-dir /path/to/rundir --workroot ./out + +# many runs locally +$PY -m $M batch --runs '/data/*/S*/rift*/' --workroot ./out + +# many runs as a condor DAG (sub templated from each run's own CIP.sub) +$PY -m $M batch --runs '/data/*/S*/rift*/' --workroot ./out --condor +condor_submit_dag ./out/condor/export_at_scale.dag +``` + +## Output layout + +``` +workroot/ + __/ + lnL_artifact.npz # the differentiable surrogate (load via jax_gp.export.load) + lnL_artifact.meta.json # coord_names, dropped-constant columns, provenance + posterior_interp.dat # posterior drawn FROM the reloaded artifact + report.json # full machine-readable report (fit + JS validation) + summary.md # human summary + JS table + marginals.png # interp-vs-CIP 1D marginal overlay + condor/ # (batch --condor) DAG + sub + per-job logs + batch_summary.json # (batch local) one line per run +``` + +## Key options + +| flag | default | meaning | +|---|---|---| +| `--method` | `quadgp` | surrogate: `quadgp` (PE-grade Fisher core + GP residual) · `svgp` (faster, low-D) · `rff` · `exact` | +| `--mass-coord` | `eta` | second mass coordinate to fit in: `eta` (Fisher-quadratic; correct `q`) · `delta_mc` | +| `--keep-curv-frac` | `0.01` | keep core eigen-curvature above this fraction of max (small ⇒ retains the gentle eta curvature) | +| `--sampler` | `auto` | validation sampler: `auto` (nuts if >3 fit dims, else gaussian) · `gaussian` · `nuts` · `flow` | +| `--n-samples` | 40000 | gaussian importance-sampling proposal draws | +| `--cap-points` | 8000 | stratified ("tree-ring") downselect of ILE points before the fit | +| `--n-features` | 256 | SVGP inducing points / RFF features | +| `--lnL-offset` | 40 | keep `lnL > max − offset` | +| `--no-plot` | — | skip `marginals.png` (used by condor jobs) | + +## Use cases / coverage + +| case | status | notes | +|---|---|---| +| **(a) precessing** (8-D: mc, eta, s1x..s2z) | ✅ supported | spins sampled in `(χ,cosθ,φ)`; all 11 marginals reported | +| **(b) aligned** (2–4-D: mc, eta, [s1z,s2z]) | ✅ supported | zero-spin / aligned-spin runs; constant spin columns auto-dropped | +| **(c) + distance export** (`*.dgrid` / `all_dgrid.dat`) | 🚧 detected, **not yet exported** | `discover_run` sets `has_dgrid`; the run's *intrinsic* `all.net` export still runs and validates. The (intrinsic + luminosity-distance) surrogate is the **next active track** — the dgrid data is still being produced. | + +The all-parameter JS (masses, `chi_eff`, `chiMinus`, cylindrical-polar spins) lets you +see immediately which physical direction a given run gets wrong — e.g. low-mass events +tend to stress the *aligned-spin* (`chi_eff`/`s1z`/`s2z`) direction. + +## Validated on + +| run | dims | method | sampler | ESS | JS mc | JS q | JS chi_eff | +|---|---|---|---|---|---|---|---| +| distance_grid_e2e (aligned, 2-D, mc∈[23,35]) | 2 | quadgp | gaussian | ~30000 | 0.008 | 0.011 | 0 (no spin) | +| S240426s v5PHM (precessing, 8-D, mc∈[30,90]) | 8 | quadgp | nuts | 2100 | 0.007 | 0.011 | 0.008 | + +All three marginals are PE-grade and apples-to-apples on both runs, using the defaults +(`--method quadgp --mass-coord eta --keep-curv-frac 0.01`). See the tuning note below +for why the fit coordinate (`eta`) is what makes `q` work. + +## Surrogate tuning (mc/q fidelity) — the fit-coordinate matters + +The mass-ratio (`q`) marginal is recovered correctly only when the surrogate is fit +in the variable the lnL **Fisher is actually quadratic in: `eta`, not `delta_mc`.** +Since `eta = ¼(1−delta_mc²)`, the curvature in `delta_mc` at the peak is suppressed by +`delta_mc*²` (it vanishes toward equal mass) — so in `delta_mc` the quadratic core +sees a *flat* direction it cannot capture, the GP residual must carry the whole `q` +falloff, over-smooths it, and the posterior grows a spurious low-q tail. + +Fix (now the **default**): fit in `eta` (`--mass-coord eta`) with +`--keep-curv-frac 0.01` so the core *retains* the now-real eta curvature, while still +**sampling in `delta_mc`** (smooth prior `∝ eta^(−6/5)`, no equal-mass singularity; +better NUTS geometry). On the 8-D precessing **S240426s** (mc∈[30,90]): + +| fit coord | keep_curv_frac | JS mc | JS q | JS chi_eff | q (interp/CIP) | +|---|---|---|---|---|---| +| delta_mc | 0.05 | 0.015 | 0.056 | 0.009 | 0.638 / 0.718 | +| **eta** | **0.01** | **0.007** | **0.011** | **0.008** | **0.694 / 0.718** | + +i.e. `q` JS **5×** better and `mc` **2×** better — all three now PE-grade and +apples-to-apples. The reported `holdout_rmse` is over the peak region (within 15 nats); +the eta quadratic core extrapolates steeply in the deep low-lnL tail (`holdout_rmse_all` +is large but that region has ~zero posterior weight). + +Other knobs (`--ls-lo-frac/--ls-hi-frac` smoothing length, `--n-features`, +`--cap-points`) are second-order once the fit coordinate is right; raising +`keep_curv_frac` (fewer core directions) *re-hides* the eta curvature. + +### Spin: the same lesson — fit in `(chi_eff, chiMinus)`, not `(s1z, s2z)` + +The 139-event O4b sweep showed the median event PE-grade on all 11 params, but +**low-mass events fail on aligned spin** (`chi_eff` JS median: mc 0–15 → 0.40, 15–30 → +0.07, 30–60 → 0.008 — a 50× mass gradient). Cause: low-mass systems measure aligned +spin sharply, and the well-measured `chi_eff` is a **diagonal ridge** in `(s1z, s2z)` +that an axis-aligned ARD GP + per-dimension-whitened quadratic core over-smooth. + +Fix (default `--spin-coord aligned_eff`): rotate the aligned-spin fit coordinates to +the Fisher principal axes `(chi_eff, chiMinus)` — short ARD lengthscale on the sharp +`chi_eff`, long on the broad `chiMinus` — while still **sampling** in the smooth +spherical spin coords (the per-body spin-sampling structure is taken from the *raw* +physics, not the fit-coordinate names — a subtlety: inferring it from `fit_names` +breaks once `s1z`/`s2z` are replaced by `chi_eff`/`chiMinus`). + +| event | mc | `cartesian` chi_eff JS | `aligned_eff` chi_eff JS | +|---|---|---|---| +| S250119cv | ~10 | 0.505 | **0.016** | +| S240413p | ~6.4 | 0.192 | **0.101** | +| S240426s | ~60 | 0.011 | **0.005** | + +No regression at high mass; the very-low-mass extreme (mc~6) is halved but remains the +hardest case. Across the 53 failing events `aligned_eff` improved 40 (15 now pass) and +the low-mass `chi_eff` median dropped 0.40→0.054 — **but it regressed 9 events** below +cartesian (real, not sampling noise). + +So the default is **`--spin-coord auto`**: fit *both* `aligned_eff` and `cartesian` and +keep the one with the lower peak-region holdout RMSE. Holdout RMSE reliably picks the +better coordinate (it selects `cartesian` exactly where `aligned_eff` would regress), +so `auto` captures the wins and is **never worse than `cartesian`** — e.g. S250119cv → +aligned_eff (chi_eff 0.016), S240513ei → cartesian (avoids the 0.61 regression). Cost is +~2× the fit (two fits, one sample). Defaults: `--method quadgp --mass-coord eta +--spin-coord auto --keep-curv-frac 0.01`. + +## Interpreting the JS divergence + +Because the validation now samples `exp(lnL + ln prior)` with **RIFT's own priors** +(mass-ratio prior ∝ η^(−6/5); isotropic uniform-magnitude spin prior, sampled in +`(χ,cosθ,φ)`; `alignedspin-zprior` when used), the comparison is apples-to-apples: +a non-zero JS reflects *surrogate* error, not a prior-convention mismatch. The report +carries a bootstrap stderr and an ESS quality flag; pool more CIP samples (raise +`--n-output-samples` upstream) if a small JS is statistics-limited. + + +JS is in bits (0 = identical marginals). For PE-grade agreement expect a few × 10⁻³ +to ~10⁻² bits on `mc`; weakly-constrained directions (`q`, `chi_eff`) are wider and +more sensitive to the number of independent samples in *both* posteriors — the +report carries a bootstrap stderr so you can tell when a large JS is just +statistics-limited (pool more CIP samples / raise `--n-output-samples` upstream). +``` diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/SAMPLER_NOTES.md b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/SAMPLER_NOTES.md new file mode 100644 index 000000000..2afd04f00 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/SAMPLER_NOTES.md @@ -0,0 +1,129 @@ +# Differentiable sampling of the jax_gp surrogate — design note + +`diff_sampler.py` demonstrates the concrete payoff of a **differentiable** +`lnL(theta)`: we can sample the (unnormalized) posterior with a **gradient-based** +sampler instead of brute-force Monte Carlo. This is roadmap item 4 in +`../DESIGN.md` ("derivative-aware sampler for CIP") exercised end-to-end on a +fitted surrogate. + +## Why gradient-based sampling (vs brute-force MC) + +CIP currently turns a fitted `lnL` into a posterior by **brute-force Monte +Carlo** (draw from a proposal, reweight by `exp(lnL)`). That works, but its +efficiency collapses as the posterior gets **sharp**: a proposal tuned to the +prior volume places almost all its samples where `lnL` is negligible, so the +effective sample size (ESS) per likelihood evaluation falls off a cliff. This is +exactly the **high-SNR** regime — the posterior occupies a tiny, often strongly +correlated sliver of parameter space. + +A gradient-aware sampler (HMC/NUTS, MALA, normalizing-flow MCMC) uses `∇lnL` to +walk *along* the posterior ridge rather than guessing. Where a random walk has +to shrink its step size to the width of the sharp peak (and then accepts almost +nothing), NUTS follows the gradient and adapts its trajectory length, keeping a +high acceptance rate and decorrelating fast. The GP is what makes this possible: +RF is piecewise-constant and has no usable gradient; the RFF GP exports a smooth, +exact `∇lnL`. + +## What the demo measures + +`demo_synthetic()` (d=5): + +1. builds a **known** sharp, correlated-Gaussian `lnL` (per-direction widths + ~0.05–0.2, random rotation — a sharp high-SNR-like peak), +2. fits an RFF surrogate to 3000 points (heteroscedastic noise on), +3. samples the surrogate with **NUTS** (`sample_nuts`) and with a **gradient-free + random-walk Metropolis** baseline (`sample_rwm`) given the **same lnL-evaluation + budget**, and +4. compares both recovered posteriors against the **analytic** Gaussian posterior + (the exact product of the known lnL Gaussian and the broad Normal prior). + +The "evaluation budget" is matched honestly: NUTS's cost is its total leapfrog +step count (each leapfrog step is one `lnL`+gradient evaluation, reported via +numpyro's `num_steps` extra field), and RWM is run for that same number of +proposals (one `lnL` evaluation each). + +### Measured numbers (CPU, gwkokab env, seed 0) + +| sampler | wall-clock | lnL evals | ESS (min/dim) | ESS / eval | posterior mean max\|z\| | cov rel-err | +|---|---|---|---|---|---|---| +| **NUTS** (gradient) | 7.6 s | 12 804 | 1868 | **0.146** | 0.05 | 0.08 | +| RWM (gradient-free) | 2.7 s | 12 804 | 3 | 0.0003 | 0.83 | 0.80 | +| flowMC (gradient, bonus) | ~30 s | n/a | — | — | 0.05 | 0.05 | + +Surrogate held-out RMSE: **0.008** lnL units (the RFF fit is essentially exact on +this smooth target). + +**Headline:** NUTS achieves **~570× higher ESS per lnL-evaluation** than the +gradient-free baseline on this sharp posterior, and recovers the analytic +posterior mean to ≤0.05σ in every dimension with an 8% covariance error. The +random-walk baseline, given the identical budget, has a **0.1% acceptance rate**, +≈3 effective samples, and a badly biased posterior (max 0.83σ mean error, 80% +covariance error) — the textbook failure mode of brute-force MC on a sharp peak, +and precisely the regime where the differentiable GP earns its (slower) fit cost. + +flowMC (normalizing-flow + MALA, also gradient-based) is wired up as a +best-effort bonus and recovers the posterior just as well (max\|z\|=0.05, 5% cov +error); it is heavier to set up and is **not** the required path. If its +constructor API drifts in a future release, `sample_flowMC` catches the exception +and skips with a logged note so the NUTS demo is never blocked. + +## Limitations (read before over-claiming) + +- **We sample the FIT, not the true likelihood.** The posterior recovered here is + the posterior of the *surrogate* `lnL`, including any GP fit error. On this + synthetic the surrogate is near-exact (RMSE 0.008), so surrogate error is not + the bottleneck; on real, noisier ILE data it will be, and the ESS gain must be + weighed against fit fidelity. +- **Fit coordinates, not raw physical parameters.** `lnL_physical` is + differentiable in the GP's *fit* coordinates (`model.coord_names`). Pushing the + gradient back to raw physical parameters needs a JAX reimplementation of CIP's + coordinate transforms — deliberately out of scope here (same caveat as + `export.py`). +- **Prior is a convenience, not the science prior.** The demo uses a broad Normal + around `x_mean` (scale `3·x_std`) to (a) localize the relevant region and (b) + keep NUTS in a well-behaved unconstrained space. A real run substitutes the + actual astrophysical prior; the analytic-posterior comparison accounts for this + exact broad-Normal prior so the recovery check is apples-to-apples. +- **Synthetic is Gaussian.** A Gaussian posterior is the friendly case for both + the analytic comparison and NUTS. Multimodal / heavy-tailed real posteriors are + where flowMC (global proposals) earns its keep over plain NUTS. + +## Path to a real CIP integrator swap + +1. Fit/export a real ILE `lnL` (the existing `export.save` / `export.load` path; + e.g. the GW170817 artifact). `diff_sampler.py --artifact ` already loads + such a bundle and runs NUTS on it (guarded by `export.exists`). +2. Replace the demo's broad-Normal prior with CIP's actual intrinsic-parameter + prior, expressed in fit coordinates (or add the JAX coordinate transform so the + prior can be stated physically). +3. Swap CIP's brute-force integrator for `sample_nuts` to draw posterior samples + and estimate the evidence (NUTS does not give the normalization directly — + pair with thermodynamic integration / bridge sampling, or use the flow's + density for an importance-sampling evidence estimate). +4. Validate against the existing brute-force CIP posterior on a few events before + making it a selectable integrator, then quantify the function-evaluation + savings — the GP's whole reason to exist (fewer expensive ILE evaluations) plus + this sampler's per-evaluation efficiency compound here. + +## API + +```python +from RIFT.interpolators.jax_gp import get_interpolator +from RIFT.interpolators.jax_gp.applications.diff_sampler import ( + sample_nuts, sample_rwm, sample_flowMC, demo_synthetic, +) + +model = get_interpolator("rff")(n_features=512, n_opt_steps=300).fit(X, y, y_errors=yerr) +res = sample_nuts(model, num_warmup=500, num_samples=2000) +# res["samples"], res["ess_min"], res["n_grad_evals"], res["wall_clock"], res["mean"], res["cov"] +``` + +Run the demo: + +```bash +cd MonteCarloMarginalizeCode/Code +PYTHONPATH="$PWD:$PYTHONPATH" python \ + RIFT/interpolators/jax_gp/applications/diff_sampler.py --demo synthetic +# optionally also sample a real exported artifact: +# ... --artifact /tmp/gw170817_rff +``` diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/__init__.py b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/__init__.py new file mode 100644 index 000000000..1b58ddd9a --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/__init__.py @@ -0,0 +1,6 @@ +""" +Applications that exploit the differentiable jax_gp likelihood export. + +These are the use cases that justify the GP over the (faster, non-AD) random +forest: gradient-based sampling and AD population inference. See ../DESIGN.md. +""" diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/README.md b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/README.md new file mode 100644 index 000000000..87f5ef01b --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/README.md @@ -0,0 +1,54 @@ +# CIP+RF+AV benchmark (Condor) for the jax_cip accuracy comparison + +The success metric for `jax_cip` is JS divergence of the 1D marginals (mc first) +against the **production CIP path: RF fit + Adaptive-Volume sampler**. A single CIP +run gives a few thousand samples — not enough for a reliable JS — so we launch it +**10×** and pool. SVGP/`jax_cip` is fast enough to produce its own samples without +this, so only the brute-force benchmark needs the fleet. + +## What's here +- `run_cip_rf_av.sh` — one CIP job (`--fit-method rf --sampler-method AV`) on the + GW170817 `.net`, in the same BNS coordinates/prior as `jax_cip` + (`--mc-range [1.196,1.199] --chi-max 0.05 --input-tides`). Writes `cip_rf_.xml.gz`. +- `cip_rf_benchmark.sub` — HTCondor submit, `queue 10`. + +## Resource footprint (measured on this box) +| config | peak RSS | wall | +|---|---|---| +| uncapped RF | 3.9 GB | ~70 s | +| `--cap-points 30000` | **1.25 GB** | ~76 s | + +So the submit requests `request_memory = 2048` (2 GB, comfortable margin over 1.25 GB) +and `request_cpus = 2` (threads pinned to 2 in the wrapper). The RF dominates memory; +capping points is what keeps the footprint low — keep the cap unless you need a +denser RF. + +## Run it +```bash +cd RIFT/interpolators/jax_gp/applications/benchmark_condor +# edit the absolute paths in cip_rf_benchmark.sub (RIFT_CODE / PYTHON / NET) for your site +mkdir -p out logs +condor_submit cip_rf_benchmark.sub +# 10 independent runs (CIP has no --seed, so each randomizes) -> out/cip_rf_{0..9}.xml.gz +``` + +## Compare (JS, mc first) +```bash +# pool all 10 benchmark runs on the B side; A side is the jax_cip output XML +python -m RIFT.interpolators.jax_gp.applications.compare \ + --a /path/to/jax_cip_out.xml.gz \ + --b out/'cip_rf_*.xml.gz' \ + --param mc +# repeat with --param delta_mc / lambda1 / s1z ... once mc looks good. +``` +`compare.py` prints JS in bits with a bootstrap stderr; if the stderr is comparable +to the JS, you are statistics-limited — add more benchmark runs (bump `queue`) and/or +draw more `jax_cip` samples. + +## Notes +- No `--seed` in CIP → independent streams per launch; that's why pooling 10 runs + accumulates valid statistics. +- The benchmark uses the SAME prior box as `jax_cip` (mc-range, chi-max, lambda) so the + JS reflects method differences, not prior differences. +- `RIFT_CODE` points at a stable checkout (not the ephemeral `.claude/worktrees` copy) + so the fleet keeps working after the dev worktree is cleaned up. diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/cip_rf_benchmark.sub b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/cip_rf_benchmark.sub new file mode 100644 index 000000000..0b586fec7 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/cip_rf_benchmark.sub @@ -0,0 +1,30 @@ +# HTCondor: 10x CIP (RF fit + AV sampler) for the GW170817 benchmark posterior. +# Accumulates independent statistics (CIP randomizes per launch -> 10 x 5000 = 5e4 +# benchmark samples) for the JS-vs-jax_cip accuracy comparison. +# +# Submit from this directory after editing the absolute paths below: +# mkdir -p out logs && condor_submit cip_rf_benchmark.sub +# then, once done: +# python -m RIFT.interpolators.jax_gp.applications.compare \ +# --a jax_cip_out.xml.gz --b out/'cip_rf_*.xml.gz' --param mc + +universe = vanilla +executable = run_cip_rf_av.sh +arguments = $(Process) + +# Measured peak RSS ~1.25 GB with --cap-points 30000 (3.9 GB uncapped); small margin. +request_memory = 2048 +request_cpus = 2 +request_disk = 2048 + +# Edit these absolute paths for your site: +environment = "RIFT_CODE=/home/oshaughn/research-projects-RIT/MonteCarloMarginalizeCode/Code PYTHON=/home/oshaughn/.conda/envs/gwkokab/bin/python NET=/home/oshaughn/all.net OUTDIR=$ENV(PWD)/out" + +getenv = False +should_transfer_files = NO + +output = logs/cip_rf_$(Process).out +error = logs/cip_rf_$(Process).err +log = logs/cip_rf.log + +queue 10 diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/run_cip_rf_av.sh b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/run_cip_rf_av.sh new file mode 100755 index 000000000..135303aff --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/benchmark_condor/run_cip_rf_av.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# One CIP + RF + AV intrinsic-posterior job for the GW170817 benchmark. +# Arg $1 = job index (Condor $(Process)); writes cip_rf_.xml.gz in $OUTDIR. +# +# No --seed in CIP -> each launch uses an independent random stream, so running +# this N times accumulates independent statistics for the benchmark posterior. +# +# Env (override via the submit file's `environment`): +# RIFT_CODE path to a working MonteCarloMarginalizeCode/Code (for bin/ + RIFT) +# PYTHON python with RIFT + sklearn + lal (e.g. the gwkokab env) +# NET the ILE .net input +# OUTDIR where to write outputs +set -euo pipefail +IDX="${1:-0}" +RIFT_CODE="${RIFT_CODE:-/home/oshaughn/research-projects-RIT/MonteCarloMarginalizeCode/Code}" +PYTHON="${PYTHON:-/home/oshaughn/.conda/envs/gwkokab/bin/python}" +NET="${NET:-/home/oshaughn/all.net}" +OUTDIR="${OUTDIR:-$(pwd)}" + +export PYTHONPATH="${RIFT_CODE}:${PYTHONPATH:-}" +# Keep each job's CPU/memory footprint small and predictable across 10 parallel jobs. +export OMP_NUM_THREADS=2 OPENBLAS_NUM_THREADS=2 MKL_NUM_THREADS=2 NUMEXPR_NUM_THREADS=2 + +mkdir -p "${OUTDIR}" +cd "${OUTDIR}" + +exec "${PYTHON}" "${RIFT_CODE}/bin/util_ConstructIntrinsicPosterior_GenericCoordinates.py" \ + --fname "${NET}" --fit-method rf --sampler-method AV \ + --parameter delta_mc --parameter-implied mu1 --parameter-implied mu2 \ + --parameter-implied LambdaTilde --parameter-implied DeltaLambdaTilde \ + --parameter-nofit mc --parameter-nofit s1z --parameter-nofit s2z \ + --parameter-nofit lambda1 --parameter-nofit lambda2 \ + --mc-range '[1.196,1.199]' --chi-max 0.05 --input-tides --cap-points 30000 \ + --n-eff 3000 --n-output-samples 5000 --no-plots \ + --fname-output-samples "cip_rf_${IDX}" \ + --fname-output-integral "cip_rf_int_${IDX}" diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/compare.py b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/compare.py new file mode 100644 index 000000000..5cf6839fe --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/compare.py @@ -0,0 +1,106 @@ +""" +Compare intrinsic-posterior 1D marginals via Jensen-Shannon divergence. + +The success metric for jax_cip is NOT sampler ESS -- it is how close the recovered +posterior is to the production CIP+RF benchmark. We quantify that with the JS +divergence of each 1D marginal (start with mc, the best-measured parameter), in bits +(log base 2, so JS in [0, 1]). + +Usage: + python -m RIFT.interpolators.jax_gp.applications.compare \ + --a jax_out.xml.gz --b cip_rf_out.xml.gz --param mc + +Both inputs are RIFT ChooseWaveformParams XML (what CIP and jax_cip both write). + +Caveat (per design notes): a reliable JS needs enough *independent* samples in BOTH +posteriors. jax_cip's per-run effective sample count is modest, so accumulate over +many seeds/instances (Condor) before trusting a small JS. ``js_divergence_1d`` here +also returns a bootstrap stderr so you can see when you are statistics-limited. +""" +from __future__ import annotations + +import argparse + +import numpy as np + +import RIFT.interpolators.jax_gp # noqa: F401 (enable float64 / consistent env) + + +def js_divergence_1d(a, b, bins=80, value_range=None): + """Jensen-Shannon divergence (bits) between two 1D sample sets via histograms.""" + a = np.asarray(a, float); b = np.asarray(b, float) + if value_range is None: + lo = min(a.min(), b.min()); hi = max(a.max(), b.max()) + if hi <= lo: + hi = lo + 1e-12 + value_range = (lo, hi) + edges = np.linspace(value_range[0], value_range[1], bins + 1) + pa, _ = np.histogram(a, bins=edges, density=True) + pb, _ = np.histogram(b, bins=edges, density=True) + w = np.diff(edges) + pa = pa * w; pb = pb * w # -> probabilities per bin + pa = pa / pa.sum(); pb = pb / pb.sum() + m = 0.5 * (pa + pb) + + def _kl(p, q): + mask = p > 0 + return np.sum(p[mask] * np.log2(p[mask] / q[mask])) + + return float(0.5 * _kl(pa, m) + 0.5 * _kl(pb, m)) + + +def js_with_stderr(a, b, bins=80, n_boot=200, seed=0): + """JS plus a bootstrap stderr (so you can tell when you're statistics-limited).""" + a = np.asarray(a, float); b = np.asarray(b, float) + lo = min(a.min(), b.min()); hi = max(a.max(), b.max()) + base = js_divergence_1d(a, b, bins=bins, value_range=(lo, hi)) + rng = np.random.default_rng(seed) + boots = [js_divergence_1d(rng.choice(a, len(a)), rng.choice(b, len(b)), + bins=bins, value_range=(lo, hi)) for _ in range(n_boot)] + return base, float(np.std(boots)) + + +def load_param_from_xml(fname, param="mc"): + """Load a 1D parameter array (default mc, in Msun) from a RIFT samples XML.""" + import lal + import RIFT.lalsimutils as lalsimutils + P_list = lalsimutils.xml_to_ChooseWaveformParams_array(fname) + vals = [] + for P in P_list: + if param == "mc": + vals.append(lalsimutils.mchirp(P.m1, P.m2) / lal.MSUN_SI) + elif param in ("m1", "m2"): + vals.append(getattr(P, param) / lal.MSUN_SI) + else: + vals.append(P.extract_param(param)) + return np.asarray(vals, float) + + +def load_param_pooled(fnames, param="mc"): + """Concatenate a parameter across several XMLs (pool e.g. 10 benchmark runs).""" + import glob + files = [] + for f in fnames: + files.extend(sorted(glob.glob(f)) or [f]) + return np.concatenate([load_param_from_xml(f, param) for f in files]) + + +def main(argv=None): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--a", required=True, nargs="+", + help="samples XML(s) A (e.g. jax_cip output; globs ok)") + p.add_argument("--b", required=True, nargs="+", + help="samples XML(s) B (e.g. the 10 CIP+RF benchmark runs; globs ok)") + p.add_argument("--param", default="mc") + p.add_argument("--bins", type=int, default=80) + a = p.parse_args(argv) + va = load_param_pooled(a.a, a.param) + vb = load_param_pooled(a.b, a.param) + js, se = js_with_stderr(va, vb, bins=a.bins) + print("param={} : A n={} mean={:.6g} B n={} mean={:.6g}".format( + a.param, len(va), va.mean(), len(vb), vb.mean())) + print("JS(A,B) = {:.4f} +/- {:.4f} bits".format(js, se)) + + +if __name__ == "__main__": + main() diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/diff_sampler.py b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/diff_sampler.py new file mode 100644 index 000000000..94b074aec --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/diff_sampler.py @@ -0,0 +1,489 @@ +""" +Differentiable sampling of a fitted jax_gp likelihood surrogate. + +The thesis (see ``../DESIGN.md``): once the GP gives us a *differentiable* +``lnL(theta)``, we can sample the (unnormalized) posterior with a +**gradient-based** sampler (HMC / NUTS) instead of brute-force Monte Carlo. +This pays off enormously when the posterior is **sharp** (high SNR), where a +random-walk explorer wastes almost every proposal but a gradient-informed +sampler walks straight to and around the peak. + +This module provides + +* :func:`sample_nuts` -- treat ``model.lnL_physical(theta)`` as an unnormalized + log-density, put a broad Normal prior around the fit centre, and run numpyro + NUTS. Returns samples, gradient-evaluation count, ESS and wall-clock. +* :func:`sample_rwm` -- a gradient-free random-walk Metropolis baseline that + spends the *same number of lnL evaluations*, for an apples-to-apples + efficiency comparison. +* :func:`sample_flowMC` -- best-effort flowMC wrapper (skips gracefully if the + flowMC API does not cooperate; NUTS is the required path). +* :func:`demo_synthetic` -- fit an RFF surrogate to a known sharp correlated + Gaussian lnL in d=5, sample it both ways, and report posterior-recovery + accuracy and ESS-per-lnL-evaluation for NUTS vs the gradient-free baseline. + +Everything samples the *fitted surrogate* in *fit coordinates* -- a stand-in for +the eventual CIP integrator swap, not (yet) a real physical-parameter sampler. +See ``SAMPLER_NOTES.md``. +""" +from __future__ import annotations + +import argparse +import time + +import numpy as np + +# Importing the package enables jax float64 *before* numpyro draws any arrays. +import RIFT.interpolators.jax_gp as jax_gp +from RIFT.interpolators.jax_gp import export, get_interpolator + +import jax +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from numpyro.infer import MCMC, NUTS +from numpyro.diagnostics import effective_sample_size + + +# --------------------------------------------------------------------------- # +# helpers # +# --------------------------------------------------------------------------- # +def _prior_loc_scale(model, bounds, prior_scale): + """Return ``(loc, scale)`` jnp vectors [d] for the broad Normal prior. + + If ``bounds`` (a ``(lo, hi)`` pair of length-d arrays) is given, the prior + is centred on the box midpoint with a scale spanning it; otherwise it is + built from the fit's ``x_mean`` / ``x_std`` (the region the GP was trained + on, which is exactly the relevant region). + """ + if bounds is not None: + lo, hi = (jnp.asarray(np.asarray(b, dtype=np.float64)) for b in bounds) + loc = 0.5 * (lo + hi) + scale = 0.5 * (hi - lo) + else: + loc = jnp.asarray(model.x_mean) + scale = float(prior_scale) * jnp.asarray(model.x_std) + return loc, scale + + +def _make_numpyro_model(model, loc, scale): + """Build a numpyro model: broad Normal prior + ``factor`` of the GP lnL. + + A broad *Normal* prior (rather than a hard Uniform) keeps NUTS in an + unconstrained space and well-behaved while still localizing the relevant + region; the GP ``factor`` then sculpts the actual posterior. + """ + lnL = model.lnL_physical + + def numpyro_model(): + theta = numpyro.sample("theta", dist.Normal(loc, scale).to_event(1)) + numpyro.factor("lnL", lnL(theta)) + + return numpyro_model + + +# --------------------------------------------------------------------------- # +# 1. NUTS # +# --------------------------------------------------------------------------- # +def sample_nuts(model, bounds=None, num_warmup=500, num_samples=2000, + prior_scale=3.0, seed=0): + """Sample ``model.lnL_physical`` as a log-density with numpyro NUTS. + + Parameters + ---------- + model : interpolator + A fitted (or exported/loaded) jax_gp model exposing ``lnL_physical``, + ``x_mean`` and ``x_std``. + bounds : tuple(array, array), optional + ``(lo, hi)`` length-d arrays. If given, the broad prior spans this box; + otherwise it is ``Normal(x_mean, prior_scale * x_std)``. + num_warmup, num_samples : int + NUTS warmup / sampling iterations. + prior_scale : float + Prior width in units of ``x_std`` (only used when ``bounds`` is None). + seed : int + PRNG seed. + + Returns + ------- + dict + ``samples`` [num_samples, d], ``n_grad_evals`` (leapfrog steps, the + gradient-evaluation count), ``ess`` [d], ``ess_min`` (float), + ``wall_clock`` (s), ``mean`` [d], ``cov`` [d, d]. + """ + loc, scale = _prior_loc_scale(model, bounds, prior_scale) + numpyro_model = _make_numpyro_model(model, loc, scale) + + kernel = NUTS(numpyro_model) + mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, + progress_bar=False) + rng = jax.random.PRNGKey(seed) + + t0 = time.time() + # ``num_steps`` per iteration = leapfrog steps = gradient evaluations. + mcmc.run(rng, extra_fields=("num_steps",)) + samples = np.asarray(mcmc.get_samples()["theta"]) + wall = time.time() - t0 + + # Gradient-eval count = total leapfrog steps (each step = one grad of lnL). + extra = mcmc.get_extra_fields(group_by_chain=False) + if "num_steps" in extra: + n_grad = int(np.sum(np.asarray(extra["num_steps"]))) + else: + n_grad = num_samples # conservative placeholder if field unavailable + ess = np.asarray(effective_sample_size(samples[None, ...])) + + return { + "samples": samples, + "n_grad_evals": n_grad, + "ess": ess, + "ess_min": float(np.min(ess)), + "wall_clock": wall, + "mean": samples.mean(axis=0), + "cov": np.cov(samples, rowvar=False), + "n_lnL_evals": n_grad, # NUTS: one lnL+grad per leapfrog step + } + + +# --------------------------------------------------------------------------- # +# 2. gradient-free baseline: random-walk Metropolis # +# --------------------------------------------------------------------------- # +def sample_rwm(model, bounds=None, n_evals=None, prior_scale=3.0, + step_scale=0.5, seed=0): + """Gradient-free random-walk Metropolis on the same log-density. + + Uses the *same* ``lnL_physical + broad-Normal-prior`` target as + :func:`sample_nuts` and is budgeted to spend ``n_evals`` likelihood + evaluations (one per proposal), so ESS/eval is directly comparable. + + Parameters + ---------- + n_evals : int + Number of lnL evaluations (= proposals) to spend. + step_scale : float + Proposal std in units of the prior ``scale``. Tuned loosely toward a + reasonable acceptance rate for the sharp-posterior regime. + + Returns + ------- + dict + ``samples``, ``n_lnL_evals``, ``ess``, ``ess_min``, ``wall_clock``, + ``mean``, ``cov``, ``accept_rate``. + """ + loc, scale = _prior_loc_scale(model, bounds, prior_scale) + loc = np.asarray(loc) + scale = np.asarray(scale) + d = loc.shape[0] + if n_evals is None: + n_evals = 4000 + + lnL_fn = jax.jit(model.lnL_physical) + + def log_target(theta): + # broad Normal prior (matching the numpyro model) + GP factor + lp = -0.5 * np.sum(((theta - loc) / scale) ** 2) + return float(lnL_fn(jnp.asarray(theta))) + lp + + rng = np.random.default_rng(seed) + cur = loc.copy() + cur_lp = log_target(cur) + prop_std = step_scale * scale + + samples = np.empty((n_evals, d)) + n_accept = 0 + t0 = time.time() + for i in range(n_evals): + prop = cur + prop_std * rng.standard_normal(d) + prop_lp = log_target(prop) + if np.log(rng.random()) < (prop_lp - cur_lp): + cur, cur_lp = prop, prop_lp + n_accept += 1 + samples[i] = cur + wall = time.time() - t0 + + # discard first 25% as burn-in for ESS / posterior estimates + burn = n_evals // 4 + post = samples[burn:] + ess = np.asarray(effective_sample_size(post[None, ...])) + return { + "samples": post, + "n_lnL_evals": n_evals, + "ess": ess, + "ess_min": float(np.min(ess)), + "wall_clock": wall, + "mean": post.mean(axis=0), + "cov": np.cov(post, rowvar=False), + "accept_rate": n_accept / n_evals, + } + + +# --------------------------------------------------------------------------- # +# 3. flowMC (best-effort) # +# --------------------------------------------------------------------------- # +def sample_flowMC(model, bounds=None, num_samples=2000, prior_scale=3.0, + seed=0): + """Best-effort flowMC sampler using the same gradient. + + flowMC's public API has shifted across releases; this wrapper attempts a + minimal MALA + normalizing-flow run and **skips gracefully** (returning + ``None`` with a logged note) if anything in its constructor signature does + not line up. NUTS is the supported path; this is a bonus. + """ + try: + loc, scale = _prior_loc_scale(model, bounds, prior_scale) + loc_j = jnp.asarray(loc) + scale_j = jnp.asarray(scale) + d = int(loc_j.shape[0]) + + def log_target(theta, data=None): + lp = -0.5 * jnp.sum(((theta - loc_j) / scale_j) ** 2) + return model.lnL_physical(theta) + lp + + from flowMC.Sampler import Sampler + from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle + + rng = jax.random.PRNGKey(seed) + n_chains = 20 + rng, sub = jax.random.split(rng) + init = loc_j + scale_j * jax.random.normal(sub, (n_chains, d)) + + # flowMC >= 0.4 bundle-style API. + rng, kb = jax.random.split(rng) + bundle = RQSpline_MALA_Bundle( + rng_key=kb, + n_chains=n_chains, + n_dims=d, + logpdf=log_target, + n_local_steps=50, + n_global_steps=50, + n_training_loops=3, + n_production_loops=3, + n_epochs=5, + ) + sampler = Sampler(d, n_chains, rng, + resource_strategy_bundles=bundle) + t0 = time.time() + sampler.sample(init, {}) + wall = time.time() - t0 + prod = sampler.resources["positions_production"] + chains = np.asarray(getattr(prod, "data", prod)) + samples = chains.reshape(-1, d) + return { + "samples": samples, + "wall_clock": wall, + "mean": samples.mean(axis=0), + "cov": np.cov(samples, rowvar=False), + } + except Exception as exc: # noqa: BLE001 -- best-effort by design + print("[sample_flowMC] skipped (best-effort): {}: {}".format( + type(exc).__name__, exc)) + return None + + +# --------------------------------------------------------------------------- # +# synthetic ground truth # +# --------------------------------------------------------------------------- # +def _make_sharp_gaussian(d=5, seed=0): + """Construct a known sharp correlated-Gaussian lnL in ``d`` dims. + + Returns ``(lnL_fn, mu, Sigma)`` where ``lnL_fn(x)`` = log N(x; mu, Sigma) up + to a constant, ``mu`` is the true mean and ``Sigma`` the true covariance. + Because the posterior we sample is ``lnL + broad-Normal-prior``, the + analytic posterior is a Gaussian we can compare against (the prior is so + broad relative to ``Sigma`` that the posterior ~ the lnL Gaussian; we still + compute the exact prior-corrected analytic posterior below). + """ + rng = np.random.default_rng(seed) + mu = rng.uniform(-2.0, 2.0, size=d) + # sharp + correlated: small eigenvalues, random rotation + A = rng.standard_normal((d, d)) + Q, _ = np.linalg.qr(A) + # lengthscales ~0.05..0.2 in each rotated direction => sharp peak + evals = rng.uniform(0.05, 0.2, size=d) ** 2 + Sigma = (Q * evals) @ Q.T + Sigma = 0.5 * (Sigma + Sigma.T) + Prec = np.linalg.inv(Sigma) + + def lnL_fn(X): + X = np.atleast_2d(X) + delta = X - mu + quad = np.einsum("ni,ij,nj->n", delta, Prec, delta) + return -0.5 * quad + + return lnL_fn, mu, Sigma + + +def _analytic_posterior(mu_L, Sigma_L, prior_loc, prior_scale_vec): + """Exact Gaussian posterior of ``N(mu_L, Sigma_L) * N(prior_loc, diag)``.""" + Prec_L = np.linalg.inv(Sigma_L) + Prec_p = np.diag(1.0 / prior_scale_vec ** 2) + Prec_post = Prec_L + Prec_p + Sigma_post = np.linalg.inv(Prec_post) + mu_post = Sigma_post @ (Prec_L @ mu_L + Prec_p @ prior_loc) + return mu_post, Sigma_post + + +# --------------------------------------------------------------------------- # +# demo # +# --------------------------------------------------------------------------- # +def demo_synthetic(d=5, n_train=3000, n_features=512, n_opt_steps=300, + num_warmup=500, num_samples=2000, seed=0): + """Fit an RFF surrogate to a sharp Gaussian lnL and sample it two ways. + + Demonstrates that gradient-based NUTS recovers the known posterior and does + so at far higher ESS-per-lnL-evaluation than a gradient-free random-walk + baseline given the *same* evaluation budget. + """ + print("=" * 72) + print("SYNTHETIC DEMO: sharp correlated Gaussian lnL, d = {}".format(d)) + print("=" * 72) + + # --- ground truth + training data -------------------------------------- + lnL_true, mu_true, Sigma_true = _make_sharp_gaussian(d=d, seed=seed) + rng = np.random.default_rng(seed + 1) + # Sample training points around the peak (covering the sharp region well) + # plus a broader cloud so the RFF sees the falloff. + L = np.linalg.cholesky(Sigma_true) + n_near = int(0.7 * n_train) + n_far = n_train - n_near + X_near = mu_true + (rng.standard_normal((n_near, d)) * 2.0) @ L.T + X_far = mu_true + rng.standard_normal((n_far, d)) * 1.0 + X = np.vstack([X_near, X_far]) + y = lnL_true(X) + # small per-point "MC error" so the heteroscedastic path is exercised + yerr = np.full_like(y, 0.05) + y = y + rng.normal(0.0, 0.05, size=y.shape) + + print("fitting RFF surrogate ({} pts, {} features, {} steps) ...".format( + n_train, n_features, n_opt_steps)) + t0 = time.time() + model = get_interpolator("rff")( + n_features=n_features, n_opt_steps=n_opt_steps, seed=seed + ).fit(X, y, y_errors=yerr) + print(" fit done in {:.1f}s".format(time.time() - t0)) + + # surrogate accuracy at the truth peak / held-out check + Xte = mu_true + (rng.standard_normal((500, d)) * 2.0) @ L.T + pred = model.predict(Xte) + rmse = float(np.sqrt(np.mean((pred - lnL_true(Xte)) ** 2))) + print(" surrogate held-out RMSE (lnL units): {:.3f}".format(rmse)) + + # --- NUTS --------------------------------------------------------------- + print("\n[NUTS] gradient-based sampling ...") + nuts = sample_nuts(model, num_warmup=num_warmup, num_samples=num_samples, + prior_scale=3.0, seed=seed) + print(" wall-clock : {:.2f}s".format(nuts["wall_clock"])) + print(" lnL/grad evals : {}".format(nuts["n_grad_evals"])) + print(" ESS (min/dim) : {:.0f}".format(nuts["ess_min"])) + print(" ESS per eval : {:.4f}".format( + nuts["ess_min"] / max(nuts["n_grad_evals"], 1))) + + # --- gradient-free baseline (same eval budget) -------------------------- + budget = nuts["n_grad_evals"] + print("\n[RWM] gradient-free baseline, same eval budget = {} ...".format(budget)) + rwm = sample_rwm(model, n_evals=budget, prior_scale=3.0, + step_scale=0.4, seed=seed) + print(" wall-clock : {:.2f}s".format(rwm["wall_clock"])) + print(" lnL evals : {}".format(rwm["n_lnL_evals"])) + print(" accept rate : {:.3f}".format(rwm["accept_rate"])) + print(" ESS (min/dim) : {:.0f}".format(rwm["ess_min"])) + print(" ESS per eval : {:.4f}".format( + rwm["ess_min"] / max(rwm["n_lnL_evals"], 1))) + + # --- analytic posterior to compare against ------------------------------ + prior_loc = np.asarray(model.x_mean) + prior_scale_vec = 3.0 * np.asarray(model.x_std) + mu_post, Sigma_post = _analytic_posterior( + mu_true, Sigma_true, prior_loc, prior_scale_vec) + + def _report_recovery(tag, res): + dmu = res["mean"] - mu_post + # whiten the mean error by the posterior std for a scale-free number + sd_post = np.sqrt(np.diag(Sigma_post)) + z = np.abs(dmu) / sd_post + # cov fractional error (Frobenius) + cov_err = (np.linalg.norm(res["cov"] - Sigma_post) + / np.linalg.norm(Sigma_post)) + print(" [{}] mean |z| max={:.2f} mean={:.2f} | cov rel-err={:.2f}".format( + tag, float(z.max()), float(z.mean()), float(cov_err))) + return float(z.max()), float(cov_err) + + print("\n--- posterior recovery vs analytic Gaussian posterior ---") + nuts_z, nuts_cov = _report_recovery("NUTS", nuts) + rwm_z, rwm_cov = _report_recovery("RWM ", rwm) + + # --- headline efficiency ratio ----------------------------------------- + nuts_eff = nuts["ess_min"] / max(nuts["n_grad_evals"], 1) + rwm_eff = rwm["ess_min"] / max(rwm["n_lnL_evals"], 1) + ratio = nuts_eff / max(rwm_eff, 1e-12) + print("\n" + "=" * 72) + print("HEADLINE: NUTS ESS/eval = {:.4f}, RWM ESS/eval = {:.4f}".format( + nuts_eff, rwm_eff)) + print(" NUTS is {:.1f}x more efficient per lnL-eval (sharp posterior)".format( + ratio)) + print("=" * 72) + + # --- optional flowMC ---------------------------------------------------- + fmc = sample_flowMC(model, num_samples=num_samples, seed=seed) + if fmc is not None: + _report_recovery("flowMC", fmc) + + return { + "model": model, + "nuts": nuts, + "rwm": rwm, + "flowMC": fmc, + "mu_post": mu_post, + "Sigma_post": Sigma_post, + "rmse": rmse, + "efficiency_ratio": ratio, + } + + +def demo_artifact(base, num_warmup=500, num_samples=2000, seed=0): + """Load an exported real lnL artifact and sample it with NUTS.""" + if not export.exists(base): + print("[demo_artifact] no artifact at {!r}; skipping.".format(base)) + return None + print("loading exported artifact: {}".format(base)) + model = export.load(base) + print(" coord_names: {}".format(getattr(model, "coord_names", None))) + print(" d = {}".format(np.asarray(model.x_mean).shape[0])) + nuts = sample_nuts(model, num_warmup=num_warmup, num_samples=num_samples, + prior_scale=3.0, seed=seed) + print(" NUTS: {:.2f}s, {} grad-evals, ESS(min)={:.0f}".format( + nuts["wall_clock"], nuts["n_grad_evals"], nuts["ess_min"])) + print(" posterior mean ({}):".format(getattr(model, "coord_names", "dims"))) + print(" ", np.array2string(nuts["mean"], precision=4)) + return {"model": model, "nuts": nuts} + + +# --------------------------------------------------------------------------- # +# main # +# --------------------------------------------------------------------------- # +def main(argv=None): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--demo", choices=["synthetic"], default="synthetic", + help="which built-in demo to run (default: synthetic)") + p.add_argument("--artifact", default=None, + help="base path of an exported lnL bundle to sample " + "(e.g. /tmp/gw170817_rff); skipped if absent") + p.add_argument("--dim", type=int, default=5) + p.add_argument("--n-train", type=int, default=3000) + p.add_argument("--num-warmup", type=int, default=500) + p.add_argument("--num-samples", type=int, default=2000) + p.add_argument("--seed", type=int, default=0) + args = p.parse_args(argv) + + if args.demo == "synthetic": + demo_synthetic(d=args.dim, n_train=args.n_train, + num_warmup=args.num_warmup, + num_samples=args.num_samples, seed=args.seed) + + if args.artifact is not None: + demo_artifact(args.artifact, num_warmup=args.num_warmup, + num_samples=args.num_samples, seed=args.seed) + + +if __name__ == "__main__": + main() diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/export_artifact.py b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/export_artifact.py new file mode 100644 index 000000000..a90146b8f --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/export_artifact.py @@ -0,0 +1,236 @@ +""" +Package a RIFT ILE ``.net`` file into a self-contained, differentiable lnL artifact. + +This is the "ship it" end of the jax_gp pipeline: it takes raw ILE output (the +per-point Monte-Carlo lnL evaluations CIP normally consumes) and produces a small, +portable bundle (``.npz`` + ``.meta.json``) that reconstructs a pure-JAX, +``jax.grad``-able ``lnL(theta)`` -- with no RIFT/lalsimutils dependency at load time. + +The surrogate is fit in *fit coordinates* (the decorrelated space CIP itself fits in), +so the exported lnL is differentiable in those coordinates -- recorded as +``coord_names`` in the meta -- not in the raw physical parameters. See ``ARTIFACT.md``. + +Example +------- +:: + + python export_artifact.py --net /home/oshaughn/all.net \\ + --out /tmp/gw170817_rff --coords bns +""" +from __future__ import annotations + +import argparse +import json + +import numpy as np + +from RIFT.interpolators.jax_gp import get_interpolator, export +from RIFT.interpolators.jax_gp.benchmark.datasets import ( + BNS_FIT_COORDS, + load_ile_net, + mc_delta_from_m1m2, + to_fit_coordinates, +) + +#: raw 6-parameter intrinsic columns as loaded from an ILE ``.net`` file +RAW_COORDS = ("m1", "m2", "s1z", "s2z", "lambda1", "lambda2") + + +def _build_fit_coordinates(X6, coords): + """Map the raw 6-column ILE intrinsic block ``X6`` to fit coordinates. + + Parameters + ---------- + X6 : ndarray [n, 6] + Columns ``(m1, m2, s1z, s2z, lambda1, lambda2)`` as returned by + :func:`load_ile_net`. + coords : {"bns", "raw"} + ``"bns"`` applies the decorrelated BNS transform + (:data:`~RIFT.interpolators.jax_gp.benchmark.datasets.BNS_FIT_COORDS`); + ``"raw"`` keeps the 6 raw physical parameters. + + Returns + ------- + X_fit : ndarray [n, d] + coord_names : list[str] + """ + if coords == "raw": + return np.asarray(X6, dtype=np.float64), list(RAW_COORDS) + if coords == "bns": + m1, m2, s1z, s2z, l1, l2 = X6.T + mc, dmc = mc_delta_from_m1m2(m1, m2) + X_low = np.column_stack([mc, dmc, s1z, s2z, l1, l2]) + low_names = ["mc", "delta_mc", "s1z", "s2z", "lambda1", "lambda2"] + X_fit = to_fit_coordinates(X_low, low_names, BNS_FIT_COORDS) + return np.asarray(X_fit, dtype=np.float64), list(BNS_FIT_COORDS) + raise ValueError("coords must be 'bns' or 'raw', got {!r}".format(coords)) + + +def build_artifact(net_path, out_base, coords="bns", method="rff", + sigma_cut=0.6, lnL_offset=40.0, cap_points=8000, + n_features=512, n_opt_steps=300, seed=0, + quadgp_residual="svgp"): + """Build and persist a differentiable lnL artifact from an ILE ``.net`` file. + + The pipeline mirrors CIP's fit preparation: load + de-dupe + ``sigma_cut`` the + ILE points, transform to fit coordinates, keep only the high-lnL region + (``lnL > max - lnL_offset``), optionally subsample to ``cap_points``, then fit + the chosen interpolator *using the per-point Monte-Carlo errors*. The model is + exported via :func:`RIFT.interpolators.jax_gp.export.save`, reloaded, and checked + for predict agreement and finite ``jax.grad``. + + Parameters + ---------- + net_path : str + Path to the RIFT ILE ``.net`` file. + out_base : str + Output base path; writes ``.npz`` and ``.meta.json``. + coords : {"bns", "raw"}, optional + Fit-coordinate system (see :func:`_build_fit_coordinates`). + method : str, optional + jax_gp interpolator name (``"rff"``, ``"exact"``, ``"svgp"``, ``"quadgp"``). + ``"quadgp"`` is the PE-grade quadratic-core + GP-residual surrogate; its + residual backend is selected by ``quadgp_residual``. + quadgp_residual : str, optional + Residual GP for ``method="quadgp"`` (``"svgp"``, ``"exact"``, ``"rff"``); + ignored otherwise. + sigma_cut : float, optional + Drop ILE points whose reported ``sigma_lnL`` exceeds this (CIP default 0.6). + lnL_offset : float, optional + Keep only points with ``lnL > max(lnL) - lnL_offset``. + cap_points : int or None, optional + If set and there are more surviving points, random-subsample down to this + many (like CIP ``--cap-points``). + n_features, n_opt_steps, seed : optional + Passed to the interpolator constructor (where applicable). + + Returns + ------- + dict + Metadata: ``n_train``, ``coord_names``, ``lnL_max``, ``holdout_rmse``, + plus the resolved build settings. + """ + rng = np.random.default_rng(seed) + + # 1. load ILE points (with per-point MC errors, sigma-cut + dedupe) + X6, y, yerr, _ = load_ile_net( + net_path, sigma_cut=sigma_cut, return_errors=True) + + # 2. transform to fit coordinates + X_fit, coord_names = _build_fit_coordinates(X6, coords) + + # drop any rows the coordinate transform made non-finite + ok = np.all(np.isfinite(X_fit), axis=1) & np.isfinite(y) & np.isfinite(yerr) + X_fit, y, yerr = X_fit[ok], y[ok], yerr[ok] + + # 3. lnL peak cut: keep the informative high-likelihood region + lnL_max = float(np.max(y)) + keep = y > lnL_max - lnL_offset + X_fit, y, yerr = X_fit[keep], y[keep], yerr[keep] + + # 4. optional random subsample to bound fit cost + if cap_points is not None and len(y) > cap_points: + sel = rng.choice(len(y), size=cap_points, replace=False) + X_fit, y, yerr = X_fit[sel], y[sel], yerr[sel] + + # 5. 15% holdout for an honest generalization estimate + n = len(y) + perm = rng.permutation(n) + n_hold = max(1, int(round(0.15 * n))) + hold_idx, train_idx = perm[:n_hold], perm[n_hold:] + Xtr, ytr, etr = X_fit[train_idx], y[train_idx], yerr[train_idx] + Xho, yho = X_fit[hold_idx], y[hold_idx] + + # 6. fit the chosen interpolator WITH the per-point MC errors + cls = get_interpolator(method) + kwargs = {} + for k, v in (("n_features", n_features), ("n_opt_steps", n_opt_steps), + ("seed", seed)): + if k in cls.__init__.__code__.co_varnames: + kwargs[k] = v + if method in ("quadgp", "quad", "quad-gp"): + # quadgp forwards unknown kwargs to its residual GP via **gp_kwargs, so we + # set the residual backend and pass the RFF feature count through to it. + kwargs["gp_method"] = quadgp_residual + if quadgp_residual == "rff": + kwargs["n_features"] = n_features + model = cls(**kwargs).fit(Xtr, ytr, y_errors=etr) + + # 7. export + reload, and verify the round-trip is faithful + differentiable + export.save(model, out_base, coord_names=coord_names) + reloaded = export.load(out_base) + + p_orig = model.predict(Xho) + p_reload = reloaded.predict(Xho) + if not np.allclose(p_orig, p_reload, rtol=1e-5, atol=1e-5): + raise AssertionError("reloaded predict() disagrees with original model") + + import jax + import jax.numpy as jnp + theta0 = jnp.asarray(Xtr[0], dtype=jnp.float64) + grad = np.asarray(jax.grad(reloaded.lnL_physical)(theta0)) + if not np.all(np.isfinite(grad)): + raise AssertionError("jax.grad of reloaded lnL_physical is not finite") + + # held-out RMSE (on the reloaded model, the thing users actually get) + holdout_rmse = float(np.sqrt(np.mean((p_reload - yho) ** 2))) + + return { + "net_path": net_path, + "out_base": out_base, + "coords": coords, + "method": method, + "quadgp_residual": quadgp_residual if method in ("quadgp", "quad", "quad-gp") else None, + "coord_names": coord_names, + "n_train": int(len(ytr)), + "n_holdout": int(len(yho)), + "lnL_max": lnL_max, + "holdout_rmse": holdout_rmse, + "grad_finite": True, + "sigma_cut": sigma_cut, + "lnL_offset": lnL_offset, + "cap_points": cap_points, + "n_features": n_features, + "n_opt_steps": n_opt_steps, + } + + +def main(argv=None): + """argparse CLI: build an artifact and print its metadata as JSON.""" + p = argparse.ArgumentParser( + description="Export a differentiable lnL artifact from a RIFT ILE .net file.") + p.add_argument("--net", required=True, help="path to the ILE .net file") + p.add_argument("--out", required=True, + help="output base path (writes .npz + .meta.json)") + p.add_argument("--coords", choices=("bns", "raw"), default="bns", + help="fit coordinate system (default: bns)") + p.add_argument("--method", default="rff", + help="jax_gp interpolator: rff|exact|svgp|quadgp (default: rff)") + p.add_argument("--quadgp-residual", default="svgp", + choices=("svgp", "exact", "rff"), + help="residual GP backend when --method quadgp (default: svgp)") + p.add_argument("--sigma-cut", type=float, default=0.6, + help="drop ILE points with sigma_lnL above this (default: 0.6)") + p.add_argument("--lnL-offset", type=float, default=40.0, + help="keep lnL > max - lnL_offset (default: 40.0)") + p.add_argument("--cap-points", type=int, default=8000, + help="random-subsample to at most this many points (default: 8000)") + p.add_argument("--n-features", type=int, default=512, + help="number of random Fourier features (RFF) (default: 512)") + p.add_argument("--n-opt-steps", type=int, default=300, + help="optimizer steps for the fit (default: 300)") + p.add_argument("--seed", type=int, default=0, help="RNG seed (default: 0)") + args = p.parse_args(argv) + + meta = build_artifact( + net_path=args.net, out_base=args.out, coords=args.coords, + method=args.method, sigma_cut=args.sigma_cut, lnL_offset=args.lnL_offset, + cap_points=args.cap_points, n_features=args.n_features, + n_opt_steps=args.n_opt_steps, seed=args.seed, + quadgp_residual=args.quadgp_residual) + print(json.dumps(meta, indent=2)) + return meta + + +if __name__ == "__main__": + main() diff --git a/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/export_at_scale.py b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/export_at_scale.py new file mode 100644 index 000000000..5694db354 --- /dev/null +++ b/MonteCarloMarginalizeCode/Code/RIFT/interpolators/jax_gp/applications/export_at_scale.py @@ -0,0 +1,1341 @@ +""" +export_at_scale -- point at a real RIFT run directory, ship a differentiable lnL +artifact for its ``all.net``, and validate it against the run's own posterior. + +This is the "do it for real, do it for many" wrapper around the jax_gp primitives. +``export_artifact``/``jax_cip`` already turn a *single* ``all.net`` into a pure-JAX +``lnL(theta)`` surrogate; this tool adds the three things a production sweep needs: + + 1. **Discovery.** Read a run directory the way RIFT left it -- detect the + ``all.net`` column layout (tides / no-tides / precessing all differ), parse the + active ``args_cip_list.txt`` for the fit coordinates and the prior box + (``--mc-range`` / ``--eta-range`` / ``--chi-max``), and locate the run's own + final CIP posterior (``posterior_samples-.dat``) to validate against. + + 2. **Export + validate the *deliverable*.** Fit the surrogate in dimension-agnostic + physical fit coordinates -- ``[mc, delta_mc]`` plus whichever spin/tidal columns + actually vary, so it covers aligned, BNS *and* precessing runs without needing a + hand-written coordinate transform -- export it, **reload the saved artifact**, + draw a posterior *from the reloaded bytes* (importance sampling over the run's + prior box), and report the Jensen-Shannon divergence of the ``mc``/``q``/ + ``chi_eff`` marginals against the run's CIP posterior. + + 3. **Scale.** ``batch`` discovers many run directories and either runs them locally + or emits an HTCondor DAG -- one node per run, the submit file templated from the + run's *own* ``CIP.sub`` (accounting group, singularity image, requirements) so + the jobs land in the same place the run itself ran. + +Nothing is written back into the run directory: every artifact + report goes under a +separate ``--workroot`` (default ``./export_at_scale_out``), one subdirectory per run. + +At this stage we interpolate **only ``all.net``** (the existing intrinsic ILE +deliverable). The distance-grid export is a separate track. + +CLI:: + + # one run, immediately + python -m RIFT.interpolators.jax_gp.applications.export_at_scale one \\ + --run-dir /path/to/rundir --workroot /path/to/out + + # inspect what discovery found, without doing any work + python -m RIFT.interpolators.jax_gp.applications.export_at_scale discover \\ + --run-dir /path/to/rundir + + # many runs -> a condor DAG (submit it with condor_submit_dag) + python -m RIFT.interpolators.jax_gp.applications.export_at_scale batch \\ + --runs '/data/*/S*/rift*/' --workroot /path/to/out --condor +""" +from __future__ import annotations + +import argparse +import ast +import glob +import json +import math +import os +import shlex +import sys +import time + +import numpy as np + + +# --------------------------------------------------------------------------- # +# 1. all.net column-layout detection +# --------------------------------------------------------------------------- # +# +# RIFT's util_CleanILE writes a *variable-width* composite. The intrinsic block +# grows with the physics (aligned -> precessing adds in-plane spins; BNS adds +# tides), and the trailing diagnostic columns are always, in order, +# ``lnL sigma_lnL ntot [neff]``. We therefore key off the *tail*: the intrinsic +# block is everything between the leading index column and those diagnostics. + +#: candidate names for the (1- or 2-mass + spin [+ tidal]) intrinsic block, by size +_INTRINSIC_BY_SIZE = { + 2: ["m1", "m2"], + 4: ["m1", "m2", "s1z", "s2z"], + 8: ["m1", "m2", "s1x", "s1y", "s1z", "s2x", "s2y", "s2z"], + 6: ["m1", "m2", "s1z", "s2z", "lambda1", "lambda2"], + 10: ["m1", "m2", "s1x", "s1y", "s1z", "s2x", "s2y", "s2z", + "lambda1", "lambda2"], +} + + +def detect_net_layout(path, max_probe=200): + """Infer the column map of a RIFT ``all.net`` / ``.composite`` file. + + Returns ``(cols, intrinsic_names, has_neff)`` where ``cols`` is the + name->index dict :func:`load_ile_net` consumes and ``intrinsic_names`` are the + physical columns of the intrinsic block (e.g. ``['m1','m2','s1x',...]``). + + The detection is tail-anchored (``... lnL sigma_lnL ntot [neff]``) and + sanity-checked on the candidate ``sigma_lnL`` column (a small positive MC error), + so it is robust to the aligned/precessing/tidal width differences without a + hard-coded table. + """ + rows = [] + with open(path) as fh: + for line in fh: + line = line.strip() + if not line or line.startswith("#"): + continue + rows.append(line.split()) + if len(rows) >= max_probe: + break + if not rows: + raise ValueError("no data rows in {}".format(path)) + ncols = len(rows[0]) + data = np.array([[float(x) for x in r] for r in rows if len(r) == ncols], + dtype=float) + + def _try(has_neff): + tail = 4 if has_neff else 3 + lnL_i = ncols - tail + sig_i = lnL_i + 1 + if lnL_i < 3: # need indx + at least m1,m2 + return None + n_intr = lnL_i - 1 # columns 1..lnL_i-1 (0 is index) + sig = data[:, sig_i] + # sigma_lnL is a small, finite, non-negative MC error + if not np.all(np.isfinite(sig)) or np.any(sig < 0) or np.median(sig) > 10: + return None + names = _INTRINSIC_BY_SIZE.get(n_intr) + if names is None: + # unknown block size: name masses + generic spin/tidal slots so the + # fit still runs (we just lose the pretty chi_eff label). + names = ["m1", "m2"] + ["x{}".format(k) for k in range(n_intr - 2)] + cols = {"indx": 0} + for j, nm in enumerate(names): + cols[nm] = 1 + j + cols["lnL"] = lnL_i + cols["sigma_lnL"] = sig_i + cols["ntot"] = sig_i + 1 + return cols, names, has_neff + + # Prefer the 4-trailing (with neff) form; fall back to 3-trailing. + for has_neff in (True, False): + got = _try(has_neff) + if got is not None: + return got + raise ValueError( + "could not infer column layout of {} (ncols={})".format(path, ncols)) + + +# --------------------------------------------------------------------------- # +# 2. run-directory discovery +# --------------------------------------------------------------------------- # + +def _parse_pair(s): + """'[a,b]' -> [float(a), float(b)] (or None).""" + if not s: + return None + try: + v = ast.literal_eval(s) + return [float(v[0]), float(v[1])] + except Exception: + return None + + +def _last_cip_arg_line(run_dir): + """The active (last non-empty) line of args_cip_list.txt, minus its leading + iteration-label token (e.g. ``3``/``Z``/``1``).""" + p = os.path.join(run_dir, "args_cip_list.txt") + if not os.path.exists(p): + return None + lines = [ln.strip() for ln in open(p) if ln.strip()] + if not lines: + return None + toks = lines[-1].split(None, 1) + return toks[1] if len(toks) == 2 else lines[-1] + + +def _cip_opt(tokens, name, multi=False): + """Pull ``--name VALUE`` (repeatable if ``multi``) out of a token list.""" + out = [] + i = 0 + while i < len(tokens): + if tokens[i] == name: + if i + 1 < len(tokens): + out.append(tokens[i + 1]) + i += 2 + else: + i += 1 + if multi: + return out + return out[-1] if out else None + + +def _find_latest_posterior(run_dir): + """Most-recent CIP intrinsic posterior; falls back to the extrinsic fairdraw. + + Prefers ``posterior_samples-.dat`` (highest N) in the run dir, then in the + highest ``iteration__cip/``, then ``extrinsic_posterior_samples.dat``. + """ + cands = [] + for p in glob.glob(os.path.join(run_dir, "posterior_samples-*.dat")): + m = os.path.basename(p)[len("posterior_samples-"):-len(".dat")] + try: + cands.append((int(m), p)) + except ValueError: + pass + if cands: + return max(cands)[1] + it_dirs = sorted(glob.glob(os.path.join(run_dir, "iteration_*_cip")), + key=lambda d: int(d.split("_")[-2]) if d.split("_")[-2].isdigit() + else -1) + for d in reversed(it_dirs): + sub = glob.glob(os.path.join(d, "posterior_samples-*.dat")) + if sub: + return sorted(sub)[-1] + extr = os.path.join(run_dir, "extrinsic_posterior_samples.dat") + return extr if os.path.exists(extr) else None + + +def _parse_condor_env(run_dir): + """Lift accounting / singularity / requirements from the run's own CIP.sub + (or ILE.sub) so a fan-out job can land where the run itself ran.""" + env = {} + for name in ("CIP.sub", "CIP_0.sub", "ILE.sub"): + p = os.path.join(run_dir, name) + if not os.path.exists(p): + continue + for ln in open(p): + ln = ln.strip() + for key in ("accounting_group", "accounting_group_user", + "request_memory", "request_disk", "request_cpus", + "requirements"): + if ln.lower().startswith(key.lower()) and "=" in ln: + env.setdefault(key, ln.split("=", 1)[1].strip()) + if "SingularityImage" in ln and "=" in ln: + env.setdefault("singularity_image", + ln.split("=", 1)[1].strip().strip('"')) + break + return env + + +def discover_run(run_dir): + """Inspect a RIFT run directory and return a dict describing how to export + + validate it. Raises if there is no usable ``all.net``.""" + run_dir = os.path.abspath(run_dir) + net = os.path.join(run_dir, "all.net") + if not os.path.exists(net) or os.path.getsize(net) == 0: + raise FileNotFoundError("no usable all.net in {}".format(run_dir)) + + cols, intrinsic_names, has_neff = detect_net_layout(net) + + cip_line = _last_cip_arg_line(run_dir) + toks = shlex.split(cip_line) if cip_line else [] + mc_range = _parse_pair(_cip_opt(toks, "--mc-range")) + eta_range = _parse_pair(_cip_opt(toks, "--eta-range")) + chi_max = _cip_opt(toks, "--chi-max") + chi_max = float(chi_max) if chi_max else None + chi_small_max = _cip_opt(toks, "--chi-small-max") + chi_small_max = float(chi_small_max) if chi_small_max else None + aligned_prior = _cip_opt(toks, "--aligned-prior") or "uniform" + precessing = "--use-precessing" in toks + params = (_cip_opt(toks, "--parameter", multi=True) + + _cip_opt(toks, "--parameter-implied", multi=True) + + _cip_opt(toks, "--parameter-nofit", multi=True)) + n_out = _cip_opt(toks, "--n-output-samples") + + # event / pipeline tags for a readable workdir name + parts = run_dir.split(os.sep) + event = next((p for p in reversed(parts) if p.startswith(("S", "G", "GW"))), + parts[-2] if len(parts) > 1 else "event") + tag = parts[-1] + + # distance-export ("dgrid") detection. When the run marginalised the distance grid, + # it leaves per-job *.dgrid files and/or a consolidated all_dgrid.dat: the lnL is + # then a function of (intrinsic, distance). That higher-dim export is a SEPARATE + # track (still under development); here we record its presence so the caller can + # route it. The intrinsic all.net export below is unaffected. + dgrid_consolidated = next( + (p for p in (os.path.join(run_dir, "all_dgrid.dat"), + os.path.join(run_dir, "consolidated_dgrid.dat")) + if os.path.exists(p)), None) + has_dgrid = bool(dgrid_consolidated) or bool( + glob.glob(os.path.join(run_dir, "*.dgrid")) + or glob.glob(os.path.join(run_dir, "iteration_*_ile", "*.dgrid"))) + + return { + "run_dir": run_dir, + "net": net, + "ncols": len(cols) + (1 if has_neff else 0), + "cols": cols, + "intrinsic_names": intrinsic_names, + "has_spins": any(n.startswith("s") for n in intrinsic_names), + "has_tides": any(n.startswith("lambda") for n in intrinsic_names), + "precessing": precessing, + "cip_parameters": params, + "mc_range": mc_range, + "eta_range": eta_range, + "chi_max": chi_max if chi_max is not None else 0.99, + "chi_small_max": (chi_small_max if chi_small_max is not None + else (chi_max if chi_max is not None else 0.99)), + "aligned_prior": aligned_prior, + "n_output_samples": int(n_out) if n_out else 2000, + "posterior": _find_latest_posterior(run_dir), + "condor_env": _parse_condor_env(run_dir), + "event": event, + "tag": tag, + "label": "{}__{}".format(event, tag), + } + + +# --------------------------------------------------------------------------- # +# 3. dimension-agnostic fit coordinates +# --------------------------------------------------------------------------- # +# +# We fit lnL in [mc, delta_mc] + (varying spin/tidal columns). mc/delta_mc remove +# the dominant curved chirp-mass ridge; raw spin components keep the rest fully +# general (aligned -> just s1z,s2z survive as non-constant; precessing -> all six). +# Constant columns (e.g. all spins zero in a no-spin run) are dropped from the fit +# and recorded so we can reconstruct full physical vectors for chi_eff / writing. + +_CONST_TOL = 1e-6 + + +def raw_to_fit(X_raw, raw_names, mass_coord="eta", spin_coord="aligned_eff"): + """``(m1,m2,spin...) -> (mc, , , )`` columns. + + The surrogate must be fit in the coordinates the lnL Fisher is quadratic in *and* + axis-aligned with (the quadratic core whitens per-dimension and the GP residual + uses axis-aligned ARD lengthscales -- neither can represent a sharp ridge along a + diagonal direction). + + ``mass_coord``: ``"eta"`` (default) or ``"delta_mc"``. eta is the Fisher-quadratic + variable; delta_mc hides that curvature near equal mass (it is quadratic in + delta_mc), so eta fixes the mass-ratio (q) marginal. + + ``spin_coord``: ``"aligned_eff"`` (default) replaces the aligned components + ``(s1z, s2z)`` with ``(chi_eff, chiMinus)`` -- the principal axes of the + aligned-spin Fisher. In ``(s1z, s2z)`` the well-measured chi_eff is a *diagonal* + ridge an axis-aligned ARD GP over-smooths (the low-mass aligned-spin failure mode); + rotating to ``(chi_eff, chiMinus)`` makes it axis-aligned and resolvable. The + in-plane components ``(s1x, s1y, s2x, s2y)`` are kept as-is (weakly constrained). + ``"cartesian"`` keeps the raw components. + """ + idx = {n: i for i, n in enumerate(raw_names)} + m1 = X_raw[:, idx["m1"]] + m2 = X_raw[:, idx["m2"]] + mc = (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2 + if mass_coord == "eta": + second, sname = m1 * m2 / (m1 + m2) ** 2, "eta" + else: + second, sname = (m1 - m2) / (m1 + m2), "delta_mc" + cols = [mc, second] + names = ["mc", sname] + spin_names = [n for n in raw_names if n not in ("m1", "m2")] + if spin_coord == "aligned_eff" and "s1z" in idx and "s2z" in idx: + M = m1 + m2 + s1z, s2z = X_raw[:, idx["s1z"]], X_raw[:, idx["s2z"]] + cols += [(m1 * s1z + m2 * s2z) / M, (m1 * s1z - m2 * s2z) / M] + names += ["chi_eff", "chiMinus"] + for n in spin_names: # keep the in-plane components + if n in ("s1z", "s2z"): + continue + cols.append(X_raw[:, idx[n]]); names.append(n) + else: + for n in spin_names: + cols.append(X_raw[:, idx[n]]); names.append(n) + return np.column_stack(cols), names + + +def fit_to_physical(Xfit, fit_names): + """Inverse of the mass part + spin passthrough: fit coords -> physical dict + with arrays m1, m2, q, chi_eff and any spin components present.""" + idx = {n: i for i, n in enumerate(fit_names)} + mc = Xfit[:, idx["mc"]] + dmc = Xfit[:, idx["delta_mc"]] + eta = 0.25 * (1.0 - dmc ** 2) + mtot = mc * eta ** (-3.0 / 5.0) + m1 = 0.5 * mtot * (1.0 + dmc) + m2 = 0.5 * mtot * (1.0 - dmc) + out = {"m1": m1, "m2": m2, "mc": mc, "eta": eta, "q": m2 / m1} + for n in fit_names: + if n not in ("mc", "delta_mc"): + out[n] = Xfit[:, idx[n]] + s1z = out.get("s1z", np.zeros_like(m1)) + s2z = out.get("s2z", np.zeros_like(m1)) + out["chi_eff"] = (m1 * s1z + m2 * s2z) / (m1 + m2) + return out + + +def fit_prior_box(fit_names, spec): + """Per-coordinate uniform-prior box [lo, hi] in fit coords, from the run's CIP + ranges (mc, eta->delta_mc, chi-max per spin component).""" + mc_r = spec["mc_range"] or [0.9, 250.0] + eta_r = spec["eta_range"] or [0.01, 0.2499999] + d_hi = math.sqrt(max(0.0, 1.0 - 4.0 * eta_r[0])) + d_lo = math.sqrt(max(0.0, 1.0 - 4.0 * min(eta_r[1], 0.25))) + chi = spec["chi_max"] + lo, hi = [], [] + for n in fit_names: + if n == "mc": + lo.append(mc_r[0]); hi.append(mc_r[1]) + elif n == "delta_mc": + lo.append(max(0.0, d_lo)); hi.append(min(0.999, d_hi)) + elif n.startswith("s"): # spin component + lo.append(-chi); hi.append(chi) + elif n.startswith("lambda"): + lo.append(0.0); hi.append(5000.0) + else: + lo.append(-chi); hi.append(chi) + return np.array(lo), np.array(hi) + + +# --------------------------------------------------------------------------- # +# 3b. RIFT-prior-correct sampling coordinates (apples-to-apples validation) +# --------------------------------------------------------------------------- # +# +# CIP does NOT sample exp(lnL) flat: it applies its prior_map in the *sampled* +# coordinates. To compare apples-to-apples we reproduce that measure. The trick +# (read straight out of CIP's prior_map) is to sample spins in spherical +# coordinates (chi, cos_theta, phi), where RIFT's default precessing prior is +# *separable and flat* -- chi: uniform magnitude (s_magnitude_uniform_prior=1/R), +# cos_theta: uniform, phi: uniform -- so the only non-constant prior factor left is +# the MASS prior (mc_prior ~ mc; delta_mc_prior ~ eta^-6/5). The Cartesian-spin +# 1/chi^2 "singularity" is just the Jacobian of this flat spherical prior, so +# sampling in (chi,cos_theta,phi) both matches RIFT exactly and avoids the singular +# geometry. Aligned runs sample s1z,s2z in Cartesian with the s_component_zprior +# shape when --aligned-prior alignedspin-zprior was used. + +def build_sampling(spec, fit_names, spin_modes=None): + """Return the RIFT-measure sampling spec for the artifact's ``fit_names``. + + ``spin_modes`` (per body ``{"1": m, "2": m}`` with ``m`` in + ``{"sph", "cart_z", None}``) fixes how each body's spin is *sampled*, derived from + the raw physics (which spin columns vary), NOT from ``fit_names`` -- because the + fit may use ``chi_eff``/``chiMinus`` instead of ``s1z``/``s2z``. When omitted, it + is inferred from ``fit_names`` (back-compat: Cartesian fit coords). + + Returns a dict with ``names``/``lo``/``hi`` (prior box), ``ln_prior(theta)``, + ``to_fit(theta)`` (sampling coords -> the artifact's ``fit_names`` vector), + ``raw_to_sample`` (ILE data -> sampling coords) and ``to_compare`` (samples -> + every physical comparison parameter). + """ + import jax.numpy as jnp + + fit_set = set(fit_names) + eta_r = spec["eta_range"] or [0.01, 0.2499999] + mc_r = spec["mc_range"] or [0.9, 250.0] + d_hi = math.sqrt(max(0.0, 1.0 - 4.0 * eta_r[0])) + d_lo = math.sqrt(max(0.0, 1.0 - 4.0 * min(eta_r[1], 0.25))) + precessing = spec["precessing"] + zprior = spec.get("aligned_prior") == "alignedspin-zprior" + Rbody = {"1": spec["chi_max"], "2": spec["chi_small_max"]} + + # DECOUPLE fit vs sampling coordinate. The artifact may be fit in eta (the + # Fisher-quadratic variable, which the quadratic core can capture), but we always + # SAMPLE in delta_mc: its prior is smooth (eta^-6/5, no equal-mass singularity) + # and its geometry is better-conditioned for NUTS -- exactly RIFT's own choice. + # to_fit() maps the sampled delta_mc to whatever mass coordinate the artifact uses. + mass2_fit = fit_names[1] if len(fit_names) > 1 and fit_names[1] in ("eta", "delta_mc") \ + else "delta_mc" + names = ["mc", "delta_mc"] + lo = [mc_r[0], max(0.0, d_lo)] + hi = [mc_r[1], min(0.999, d_hi)] + # how each body's spin is SAMPLED -- from the raw physics (spin_modes) if given, + # else inferred from the fit Cartesian components (back-compat). + if spin_modes is None: + spin_modes = {} + for b in ("1", "2"): + present = [c for c in ("s%sx" % b, "s%sy" % b, "s%sz" % b) if c in fit_set] + spin_modes[b] = ("sph" if len(present) == 3 + else "cart_z" if present else None) + body_mode = {} + for b in ("1", "2"): + R = Rbody[b] + m = spin_modes.get(b) + if m == "sph": # precessing -> spherical sampling + body_mode[b] = ("sph", ["s%sx" % b, "s%sy" % b, "s%sz" % b]) + names += ["chi%s" % b, "cos_theta%s" % b, "phi%s" % b] + lo += [0.0, -1.0, 0.0]; hi += [R, 1.0, 2.0 * math.pi] + elif m == "cart_z": # aligned -> Cartesian s{b}z + body_mode[b] = ("cart", ["s%sz" % b]) + names.append("s%sz" % b); lo.append(-R); hi.append(R) + else: + body_mode[b] = (None, []) + lo = np.array(lo); hi = np.array(hi) + nidx = {n: i for i, n in enumerate(names)} + + mass_prior = spec.get("_mass_prior", "m1m2") + def ln_prior(theta): + mc = theta[nidx["mc"]] + dmc = theta[nidx["delta_mc"]] + eta = 0.25 * (1.0 - dmc * dmc) + if mass_prior == "flat": + lp = mc * 0.0 + else: + # uniform-in-(m1,m2) in the SAMPLED (delta_mc) coordinate: the + # (1-4eta)^-1/2 factor cancels with the d eta/d delta_mc Jacobian, leaving + # the smooth p(mc, delta_mc) ~ mc * eta^-6/5 (no equal-mass singularity). + lp = jnp.log(mc) - 1.2 * jnp.log(eta) + if zprior: # s_component_zprior on aligned comps + for b in ("1", "2"): + mode, comps = body_mode[b] + if mode == "cart": + R = Rbody[b] + for c in comps: + if c.endswith("z"): + s = theta[nidx[c]] + lp = lp + jnp.log(-jnp.log(jnp.abs(s) / R + 1e-7)) + return lp + + def to_fit(theta): + # map sampled (mc, delta_mc) to the artifact's mass coordinate (eta or delta_mc) + dmc = theta[nidx["delta_mc"]] + mass_val = 0.25 * (1.0 - dmc * dmc) if mass2_fit == "eta" else dmc + vals = {"mc": theta[nidx["mc"]], mass2_fit: mass_val} + for b in ("1", "2"): + mode, comps = body_mode[b] + if mode == "sph": + chi = theta[nidx["chi%s" % b]] + ct = theta[nidx["cos_theta%s" % b]] + ph = theta[nidx["phi%s" % b]] + st = jnp.sqrt(jnp.clip(1.0 - ct * ct, 0.0, 1.0)) + vals["s%sz" % b] = chi * ct + vals["s%sx" % b] = chi * st * jnp.cos(ph) + vals["s%sy" % b] = chi * st * jnp.sin(ph) + elif mode == "cart": + for c in comps: + vals[c] = theta[nidx[c]] + # aligned-spin principal axes, if the artifact was fit in them + if "chi_eff" in fit_set or "chiMinus" in fit_set: + eta_v = 0.25 * (1.0 - dmc * dmc) + mtot = vals["mc"] * eta_v ** (-3.0 / 5.0) + m1 = 0.5 * mtot * (1.0 + dmc); m2 = 0.5 * mtot * (1.0 - dmc) + s1z = vals.get("s1z", 0.0 * dmc); s2z = vals.get("s2z", 0.0 * dmc) + M = m1 + m2 + vals["chi_eff"] = (m1 * s1z + m2 * s2z) / M + vals["chiMinus"] = (m1 * s1z - m2 * s2z) / M + return jnp.stack([vals[n] for n in fit_names]) + + def raw_to_sample(X_raw, raw_names): + ridx = {n: i for i, n in enumerate(raw_names)} + m1 = X_raw[:, ridx["m1"]]; m2 = X_raw[:, ridx["m2"]] + mc = (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2 + dmc = (m1 - m2) / (m1 + m2) # sampling coordinate is delta_mc + cols = [mc, dmc] + for b in ("1", "2"): + mode, comps = body_mode[b] + if mode == "sph": + sx = X_raw[:, ridx["s%sx" % b]]; sy = X_raw[:, ridx["s%sy" % b]] + sz = X_raw[:, ridx["s%sz" % b]] + chi = np.sqrt(sx ** 2 + sy ** 2 + sz ** 2) + ct = np.where(chi > 1e-12, sz / np.clip(chi, 1e-12, None), 0.0) + ph = np.mod(np.arctan2(sy, sx), 2.0 * np.pi) + cols += [chi, ct, ph] + elif mode == "cart": + for c in comps: + cols.append(X_raw[:, ridx[c]]) + return np.column_stack(cols) + + def to_compare(S): + """Samples (in sampling coords) -> every physical comparison parameter: + masses (mc, eta, q), aligned spin combos (chi_eff, chiMinus), and the + cylindrical-polar spin of each body (s{b}z, chi{b}_perp, phi{b}).""" + mc = S[:, nidx["mc"]]; dmc = S[:, nidx["delta_mc"]] + eta = 0.25 * (1.0 - dmc ** 2) + mtot = mc * eta ** (-3.0 / 5.0) + m1 = 0.5 * mtot * (1.0 + dmc); m2 = 0.5 * mtot * (1.0 - dmc) + out = {"mc": mc, "q": m2 / m1, "eta": eta, "m1": m1, "m2": m2} + for b in ("1", "2"): + mode, _ = body_mode[b] + if mode == "sph": + chi = S[:, nidx["chi%s" % b]]; ct = S[:, nidx["cos_theta%s" % b]] + out["s%sz" % b] = chi * ct + out["chi%s_perp" % b] = chi * np.sqrt(np.clip(1.0 - ct ** 2, 0.0, 1.0)) + out["phi%s" % b] = S[:, nidx["phi%s" % b]] + elif mode == "cart" and ("s%sz" % b) in nidx: + out["s%sz" % b] = S[:, nidx["s%sz" % b]] + else: + out["s%sz" % b] = np.zeros_like(m1) + out["chi_eff"] = (m1 * out["s1z"] + m2 * out["s2z"]) / (m1 + m2) + out["chiMinus"] = (m1 * out["s1z"] - m2 * out["s2z"]) / (m1 + m2) + return out + + return {"names": names, "lo": lo, "hi": hi, "ln_prior": ln_prior, + "to_fit": to_fit, "raw_to_sample": raw_to_sample, + "to_compare": to_compare} + + +# --------------------------------------------------------------------------- # +# 4. fit + export the artifact (the deliverable) +# --------------------------------------------------------------------------- # + +def fit_and_export(spec, out_base, method="svgp", sigma_cut=0.6, lnL_offset=40.0, + cap_points=8000, n_features=256, n_opt_steps=300, seed=0, + quadgp_residual="svgp", keep_curv_frac=0.01, + ls_lo_frac=0.2, ls_hi_frac=1.0, mass_coord="eta", + spin_coord="auto"): + """Build, persist and cold-reload-verify a differentiable lnL artifact for the + run's ``all.net``. Returns a metadata dict (also the fit-coord arrays needed by + validation, under private keys). + + ``spin_coord="auto"`` (default) fits BOTH ``aligned_eff`` (chi_eff/chiMinus + principal axes -- fixes the sharp low-mass aligned spin) and ``cartesian`` + (s1z/s2z) and keeps whichever has the lower peak-region holdout RMSE. This is the + "never regress" guard: aligned_eff is a large net win but worse on a minority of + events, and holdout RMSE reliably picks the better of the two (it selects cartesian + exactly where aligned_eff would regress). Pass an explicit ``aligned_eff`` / + ``cartesian`` to skip the selection (half the fit cost).""" + import shutil + if spin_coord == "auto": + cands = [] + for sc in ("aligned_eff", "cartesian"): + m = fit_and_export( + spec, out_base + "__" + sc, method=method, sigma_cut=sigma_cut, + lnL_offset=lnL_offset, cap_points=cap_points, n_features=n_features, + n_opt_steps=n_opt_steps, seed=seed, quadgp_residual=quadgp_residual, + keep_curv_frac=keep_curv_frac, ls_lo_frac=ls_lo_frac, + ls_hi_frac=ls_hi_frac, mass_coord=mass_coord, spin_coord=sc) + cands.append((m["holdout_rmse"], sc, m)) + cands.sort(key=lambda t: t[0]) + best = cands[0][2] + for ext in (".npz", ".meta.json"): # promote the winner to out_base + shutil.copyfile(best["out_base"] + ext, out_base + ext) + for _, sc, _ in cands: # clean up the candidate exports + for ext in (".npz", ".meta.json"): + try: + os.remove(out_base + "__" + sc + ext) + except OSError: + pass + best["out_base"] = out_base + best["spin_coord"] = cands[0][1] + best["spin_coord_auto"] = {sc: round(r, 3) for r, sc, _ in cands} + return best + + from RIFT.interpolators.jax_gp import get_interpolator, export + from RIFT.interpolators.jax_gp.benchmark.datasets import load_ile_net + from RIFT.interpolators.jax_gp.applications.jax_cip import _tree_ring_select + + # 1. load the intrinsic block with the *detected* layout, sigma-cut + dedupe + X_raw, y, yerr, _ = load_ile_net( + spec["net"], fit_params=tuple(spec["intrinsic_names"]), + cols=spec["cols"], sigma_cut=sigma_cut, return_errors=True) + + # how each body's spin is SAMPLED is set by the raw physics (which spin columns + # vary), independent of the fit representation (Cartesian vs chi_eff/chiMinus). + raw_names = list(spec["intrinsic_names"]) + ridx = {n: i for i, n in enumerate(raw_names)} + spin_modes = {} + for b in ("1", "2"): + comps = [c for c in ("s%sx" % b, "s%sy" % b, "s%sz" % b) if c in ridx] + varies = [c for c in comps if X_raw[:, ridx[c]].std() > _CONST_TOL] + inplane = any(c.endswith(("x", "y")) for c in varies) + spin_modes[b] = ("sph" if inplane else "cart_z" if varies else None) + + # 2. physical fit coordinates; drop columns that do not vary (record them) + Xfit_all, fit_names_all = raw_to_fit(X_raw, raw_names, + mass_coord=mass_coord, spin_coord=spin_coord) + spread = Xfit_all.std(axis=0) + keep = spread > _CONST_TOL + keep[0] = keep[1] = True # always keep mc, delta_mc + fit_names = [n for n, k in zip(fit_names_all, keep) if k] + constants = {n: float(Xfit_all[:, i].mean()) + for i, (n, k) in enumerate(zip(fit_names_all, keep)) if not k} + Xfit = Xfit_all[:, keep] + + # 3. high-lnL region + stratified ("tree-ring") downselect to bound fit cost + # (carry the raw intrinsic columns through the same masks: validation needs + # them to build the proposal in RIFT's sampling coordinates) + ok = np.all(np.isfinite(Xfit), axis=1) & np.isfinite(y) & np.isfinite(yerr) + Xfit, y, yerr, X_raw = Xfit[ok], y[ok], yerr[ok], X_raw[ok] + lnL_max = float(np.max(y)) + band = y > lnL_max - lnL_offset + Xfit, y, yerr, X_raw = Xfit[band], y[band], yerr[band], X_raw[band] + if cap_points and len(y) > cap_points: + sel = _tree_ring_select(y, cap_points, seed=seed) + Xfit, y, yerr, X_raw = Xfit[sel], y[sel], yerr[sel], X_raw[sel] + + # 4. honest 15% holdout + rng = np.random.default_rng(seed) + n = len(y) + perm = rng.permutation(n) + n_hold = max(1, int(round(0.15 * n))) + ho, tr = perm[:n_hold], perm[n_hold:] + Xtr, ytr, etr, Xho, yho = Xfit[tr], y[tr], yerr[tr], Xfit[ho], y[ho] + + # 5. fit the interpolator with the per-point MC errors + cls = get_interpolator(method) + if method in ("rff", "gp-jax-rff"): + model = cls(n_features=n_features, n_opt_steps=n_opt_steps, seed=seed) + elif method in ("svgp", "gp-jax-svgp"): + model = cls(n_inducing=n_features, n_opt_steps=n_opt_steps, seed=seed) + elif method == "quadgp": + # forward only the kwargs the chosen residual backend accepts (via **gp_kwargs) + if quadgp_residual == "svgp": + gpkw = dict(n_inducing=n_features, seed=seed, + ls_lo_frac=ls_lo_frac, ls_hi_frac=ls_hi_frac) + elif quadgp_residual == "rff": + gpkw = dict(n_features=n_features, seed=seed) + else: # exact: no inducing/seed kwargs + gpkw = {} + model = cls(gp_method=quadgp_residual, n_opt_steps=n_opt_steps, + keep_curv_frac=keep_curv_frac, **gpkw) + else: + model = cls(n_opt_steps=n_opt_steps) + model = model.fit(Xtr, ytr, y_errors=etr) + model.coord_names = list(fit_names) + + # 6. export, reload, and prove the saved bytes are faithful + differentiable + export.save(model, out_base, coord_names=fit_names, + extra_meta={"constants": constants, "event": spec["event"], + "tag": spec["tag"], "net": spec["net"]}) + reloaded = export.load(out_base) + p_reload = reloaded.predict(Xho) + if not np.allclose(model.predict(Xho), p_reload, rtol=1e-5, atol=1e-4): + raise AssertionError("reloaded predict() disagrees with the fitted model") + import jax + import jax.numpy as jnp + g = np.asarray(jax.grad(reloaded.lnL_physical)(jnp.asarray(Xtr[0]))) + if not np.all(np.isfinite(g)): + raise AssertionError("jax.grad of reloaded lnL is not finite") + # Headline holdout RMSE is over the PE-relevant peak region (within 15 nats of the + # peak); the eta quadratic core extrapolates steeply in the deep low-lnL tail + # (~zero posterior weight), which would otherwise dominate a plain RMSE. + rmse_all = float(np.sqrt(np.mean((p_reload - yho) ** 2))) + peak = yho > (lnL_max - 15.0) + holdout_rmse = float(np.sqrt(np.mean((p_reload[peak] - yho[peak]) ** 2))) \ + if peak.any() else rmse_all + + meta = { + "out_base": out_base, "method": method, "coord_names": fit_names, + "constants": constants, "n_train": int(len(ytr)), + "n_holdout": int(len(yho)), "lnL_max": lnL_max, + "holdout_rmse": holdout_rmse, "holdout_rmse_all": rmse_all, + "mass_coord": mass_coord, "spin_coord": spin_coord, + "keep_curv_frac": keep_curv_frac, + "grad_finite": True, "n_intrinsic_dims": len(fit_names), + } + # private handoff to validation (not serialised in the public report verbatim) + meta["_fit_names"] = fit_names + meta["_Xfit"] = Xfit + meta["_y"] = y + meta["_X_raw"] = X_raw + meta["_raw_names"] = list(spec["intrinsic_names"]) + meta["_spin_modes"] = spin_modes + return meta + + +# --------------------------------------------------------------------------- # +# 5. validation: sample the reloaded artifact, compare to the CIP posterior +# --------------------------------------------------------------------------- # + +def _load_posterior_dat(path): + """Load a RIFT ``posterior_samples-*.dat`` / ``extrinsic_posterior_samples.dat`` + into a name->array dict using its ``# ...`` header, and derive the full intrinsic + comparison set (mc, eta, q, chi_eff, chiMinus, and cylindrical-polar spins + s{b}z, chi{b}_perp, phi{b}) so it lines up with :func:`build_sampling`'s + ``to_compare``.""" + header = None + with open(path) as fh: + for ln in fh: + if ln.strip().startswith("#"): + header = ln.lstrip("#").split() + break + data = np.loadtxt(path) + if data.ndim == 1: + data = data[None, :] + if header is None or len(header) != data.shape[1]: + # headerless / mismatched: fall back to the canonical CIP column order + header = ["m1", "m2", "a1x", "a1y", "a1z", "a2x", "a2y", "a2z", + "mc", "eta", "indx", "Npts", "ra", "dec", "tref", "phiorb", + "incl", "psi", "dist", "p", "ps", "lnL", "mtotal", "q"] + header = header[:data.shape[1]] + cols = {n: data[:, i] for i, n in enumerate(header)} + if {"m1", "m2"} <= cols.keys(): + m1, m2 = cols["m1"], cols["m2"] + cols.setdefault("mc", (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2) + cols.setdefault("q", np.minimum(m1, m2) / np.maximum(m1, m2)) + cols.setdefault("eta", m1 * m2 / (m1 + m2) ** 2) + # cylindrical-polar spins from the Cartesian a{b}{x,y,z} columns + for b in ("1", "2"): + ax, ay, az = "a%sx" % b, "a%sy" % b, "a%sz" % b + if {ax, ay, az} <= cols.keys(): + cols.setdefault("s%sz" % b, cols[az]) + cols.setdefault("chi%s_perp" % b, + np.sqrt(cols[ax] ** 2 + cols[ay] ** 2)) + cols.setdefault("phi%s" % b, + np.mod(np.arctan2(cols[ay], cols[ax]), 2 * np.pi)) + if {"a1z", "a2z"} <= cols.keys(): + cols.setdefault("chi_eff", (m1 * cols["a1z"] + m2 * cols["a2z"]) / (m1 + m2)) + cols.setdefault("chiMinus", (m1 * cols["a1z"] - m2 * cols["a2z"]) / (m1 + m2)) + return cols + + +def _sample_flow_v060(target, lo, hi, init_theta=None, n_samples=8000, n_chains=30, + n_train_loops=6, n_prod_loops=2, n_epochs=12, seed=0): + """flowMC (>=0.6.0) normalizing-flow importance sampling on the box [lo,hi]. + + A self-contained replacement for the legacy ``jax_cip.sample_flow_is`` (whose + positional ``Sampler(...)`` call broke under flowMC 0.6.0's keyword-only API). + Trains an RQ-spline flow over a sigmoid-into-box latent, then i.i.d. draws + + importance weights ``exp(target + log_jac - log_q)``. Use in ``rift_ad_export``. + """ + import jax + import jax.numpy as jnp + from flowMC.Sampler import Sampler + from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle + + lo = jnp.asarray(lo); hi = jnp.asarray(hi); span = hi - lo + d = int(lo.shape[0]) + + def theta_of_u(u): + return lo + span * jax.nn.sigmoid(u) + + def log_jac(u): + return jnp.sum(jnp.log(span) + jax.nn.log_sigmoid(u) + jax.nn.log_sigmoid(-u)) + + def u_logpdf(u, data=None): + return target(theta_of_u(u)) + log_jac(u) + + key = jax.random.PRNGKey(seed) + key, kb, ks, ki, kd = jax.random.split(key, 5) + bundle = RQSpline_MALA_Bundle( + rng_key=kb, n_chains=n_chains, n_dims=d, logpdf=u_logpdf, + n_local_steps=50, n_global_steps=50, n_training_loops=n_train_loops, + n_production_loops=n_prod_loops, n_epochs=n_epochs) + sampler = Sampler(n_dim=d, n_chains=n_chains, rng_key=ks, + resource_strategy_bundles=bundle) + if init_theta is not None: + frac = np.clip((np.asarray(init_theta, float) - np.asarray(lo)) + / np.asarray(span), 1e-3, 1 - 1e-3) + u0 = np.log(frac / (1 - frac)) + init = jnp.asarray(u0)[None, :] + 0.3 * jax.random.normal(ki, (n_chains, d)) + else: + init = 0.3 * jax.random.normal(ki, (n_chains, d)) + sampler.sample(init, {}) + + flow = sampler.resources["model"] + u = jnp.asarray(flow.sample(kd, n_samples)) + theta = np.asarray(jax.vmap(theta_of_u)(u)) + log_q = np.asarray(jax.vmap(flow.log_prob)(u)) # log_prob is per-sample in 0.6.0 + log_p = np.asarray(jax.jit(jax.vmap(u_logpdf))(u)) + log_w = np.array(log_p - log_q, dtype=np.float64) + m = np.max(log_w) + logZ = float(m + np.log(np.mean(np.exp(log_w - m)))) + w = np.exp(log_w - m); w = w / w.sum() + ess = float(1.0 / np.sum(w ** 2)) + rng = np.random.default_rng(seed) + idx = rng.choice(n_samples, size=min(8000, n_samples), replace=True, p=w) + samples = theta[idx] + return {"samples": samples, "ess": ess, "ess_frac": ess / n_samples, + "logZ": logZ, "mean": samples.mean(0), "std": samples.std(0)} + + +#: every intrinsic parameter the validation reports a JS on. Masses, the aligned-spin +#: combinations (chi_eff, chiMinus), and the cylindrical-polar spin of each body +#: (aligned component s{b}z, in-plane magnitude chi{b}_perp, azimuth phi{b}). +ALL_COMPARE_PARAMS = ( + "mc", "eta", "q", "chi_eff", "chiMinus", + "s1z", "chi1_perp", "phi1", "s2z", "chi2_perp", "phi2", +) + + +def validate_artifact(spec, fit_meta, out_dir, n_samples=40000, inflate=1.2, + seed=0, sampler="auto", compare_params=ALL_COMPARE_PARAMS): + """Draw a posterior from the *reloaded* artifact and JS-compare its marginals to + the run's CIP posterior. Writes ``posterior_interp.dat`` and returns the report. + + The target is the run's actual posterior ``lnL(theta) + ln prior(theta)`` -- NOT a + flat-prior caricature -- so the comparison is apples-to-apples with CIP. Sampling + is done in RIFT's own coordinates (:func:`build_sampling`): spins in + ``(chi, cos_theta, phi)``, where RIFT's isotropic prior is flat, plus the + non-uniform mass-prior shape (``mc_prior`` ~ mc, ``eta_prior`` ~ eta^-6/5). This + both matches CIP's measure and avoids the Cartesian-spin 1/chi^2 singularity. + + ``sampler``: ``"gaussian"`` -- fast mu-matched importance sampling (great in low + dimension); ``"nuts"`` -- gradient-based NUTS preconditioned with the data + covariance, which exploits the artifact's AD gradients and explores the curved, + high-dimensional precessing posterior far better than a single Gaussian proposal; + ``"auto"`` (default) picks ``nuts`` when there are >3 sampling dimensions, else + ``gaussian``. + """ + from RIFT.interpolators.jax_gp import export + from RIFT.interpolators.jax_gp.applications.jax_cip import ( + sample_gaussian_is, sample_nuts_muframe) + from RIFT.interpolators.jax_gp.applications.compare import js_with_stderr + + import jax.numpy as jnp + fit_names = fit_meta["_fit_names"] + X_raw, raw_names = fit_meta["_X_raw"], fit_meta["_raw_names"] + y = fit_meta["_y"] + reloaded = export.load(fit_meta["out_base"]) + + # Sample in RIFT's OWN coordinates + measure (apples-to-apples): spins in + # (chi,cos_theta,phi) where the isotropic prior is flat, plus the mass-prior + # shape. The NUTS/IS target is lnL(theta) + ln prior(theta) -- exactly the CIP + # posterior, not a flat-prior caricature. + smp = build_sampling(spec, fit_names, spin_modes=fit_meta.get("_spin_modes")) + names, lo, hi = smp["names"], smp["lo"], smp["hi"] + Xs = smp["raw_to_sample"](X_raw, raw_names) # ILE data in sampling coords + + def target(theta): + return reloaded.lnL_physical(smp["to_fit"](theta)) + smp["ln_prior"](theta) + + # proposal/preconditioner: lnL-weighted mean+cov of the data in sampling coords, + # restricted to the prior box + inb = np.all((Xs >= lo) & (Xs <= hi), axis=1) + Xp, yp = (Xs[inb], y[inb]) if inb.sum() >= 10 else (Xs, y) + w = np.exp(yp - yp.max()); w /= w.sum() + gmean = (Xp * w[:, None]).sum(0) + gcov = np.atleast_2d(np.cov(Xp.T, aweights=w)) + if gcov.shape[0] == 1: # 1-D: cov() returns a scalar + gcov = gcov.reshape(1, 1) + + if sampler == "auto": + sampler = "nuts" if len(names) > 3 else "gaussian" + if sampler == "flow": + # normalizing-flow IS (flowMC >=0.6.0); needs the rift_ad_export env. + res = _sample_flow_v060(target, lo, hi, init_theta=gmean, + n_samples=max(8000, spec["n_output_samples"]), + seed=seed) + elif sampler == "nuts": + # Gradient-based NUTS, dense mass matrix seeded from the data covariance. + # Not proposal-limited -> explores the curved, weakly-constrained precessing + # directions; uses the artifact's jax.grad lnL (+ prior) directly. + ndraw = max(2000, spec["n_output_samples"]) + res = sample_nuts_muframe(target, gmean, gcov, lo, hi, num_warmup=1000, + num_samples=ndraw, num_chains=2, seed=seed) + else: + res = sample_gaussian_is(target, gmean, gcov, lo, hi, + n_samples=n_samples, inflate=inflate, seed=seed) + samples = res["samples"] # [n, d] in sampling coords + + # spin-magnitude constraint is automatic: chi in [0,R] by the box, and the + # spherical map keeps |spin| = chi <= R. Map to physical comparison params. + phys = smp["to_compare"](samples) + + # write the interpolated posterior (RIFT-ish .dat, every intrinsic column we have) + os.makedirs(out_dir, exist_ok=True) + post_path = os.path.join(out_dir, "posterior_interp.dat") + out_names = [n for n in ("m1", "m2", "mc", "eta", "q", "chi_eff", "chiMinus", + "s1z", "chi1_perp", "phi1", "s2z", "chi2_perp", "phi2") + if n in phys] + np.savetxt(post_path, np.column_stack([phys[n] for n in out_names]), + header=" ".join(out_names)) + + # JS divergence vs the run's own CIP posterior + js = {} + ref = None + if spec["posterior"] and os.path.exists(spec["posterior"]): + ref = _load_posterior_dat(spec["posterior"]) + for prm in compare_params: + if prm in phys and prm in ref and len(phys[prm]) > 20: + a, b = phys[prm], ref[prm] + if np.std(a) < 1e-9 and np.std(b) < 1e-9: + js[prm] = {"js_bits": 0.0, "js_stderr": 0.0, "degenerate": True, + "interp_mean": float(np.mean(a)), "interp_std": 0.0, + "ref_mean": float(np.mean(b)), "ref_std": 0.0, + "n_interp": int(len(a)), "n_ref": int(len(b))} + continue + val, se = js_with_stderr(a, b) + js[prm] = { + "js_bits": val, "js_stderr": se, + "interp_mean": float(np.mean(a)), "ref_mean": float(np.mean(b)), + "interp_std": float(np.std(a)), "ref_std": float(np.std(b)), + "n_interp": int(len(a)), "n_ref": int(len(b)), + } + # Honesty about sampling: a JS computed from few *independent* draws is noisy + # regardless of the (resampled-with-replacement) bootstrap stderr. Flag it. + ess = float(res.get("ess", float("nan"))) + if not np.isfinite(ess) or ess < 100: + quality = "sampling-limited" + elif ess < 400: + quality = "marginal" + else: + quality = "ok" + return { + "posterior_interp": post_path, + "reference_posterior": spec["posterior"], + "n_posterior_samples": int(len(samples)), + "sampler": sampler, + "is_ess": ess, + "is_ess_frac": float(res.get("ess_frac", float("nan"))), + "logZ": res.get("logZ"), + "quality": quality, + "js": js, + } + + +# --------------------------------------------------------------------------- # +# 6. orchestration: one run end-to-end +# --------------------------------------------------------------------------- # + +def run_one(run_dir, workroot, method="quadgp", n_samples=40000, seed=0, + cap_points=8000, n_features=256, n_opt_steps=300, lnL_offset=40.0, + sigma_cut=0.6, sampler="auto", keep_curv_frac=0.01, + ls_lo_frac=0.2, ls_hi_frac=1.0, mass_coord="eta", + spin_coord="auto", write_plot=True): + """Discover -> fit+export -> validate one run; write all artifacts under + ``workroot/