Source code for gwsumm.plot.range

# -*- coding: utf-8 -*-
# Copyright (C) Duncan Macleod (2013)
#
# This file is part of GWSumm.
#
# GWSumm is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWSumm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWSumm.  If not, see <http://www.gnu.org/licenses/>.

"""Definitions for range plots
"""

from math import pi
from heapq import nlargest
from itertools import combinations

import numpy

from matplotlib.ticker import (LogLocator, MaxNLocator)

import gwpy.astro
from gwpy.segments import (Segment, SegmentList)
from gwpy.timeseries import TimeSeries

from gwdetchar.plot import texify

from .registry import (get_plot, register_plot)
from .utils import hash
from ..data import (get_range_channel, get_range,
                    get_range_spectrogram, get_timeseries)
from ..segments import get_segments
from ..channels import split as split_channels

__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'
__credits__ = 'Alex Urban <alexander.urban@ligo.org>'


# -- utils --------------------------------------------------------------------

# Map each ``range_func`` name accepted in the plot configuration to the
# `gwpy.astro` function that implements it.  Note that the ``burst_range_psd``
# name is served by `gwpy.astro.burst_range_spectrum`.
RAN_DICT = dict(
    sensemon_range=gwpy.astro.sensemon_range,
    sensemon_range_psd=gwpy.astro.sensemon_range_psd,
    inspiral_range=gwpy.astro.inspiral_range,
    inspiral_range_psd=gwpy.astro.inspiral_range_psd,
    burst_range=gwpy.astro.burst_range,
    burst_range_psd=gwpy.astro.burst_range_spectrum,
)


def _get_params(keys, pargs, nchans=1):
    """Return a `dict` of `list` of plot arguments for every channel
    """
    params = {}
    for key in keys:
        try:
            value = pargs.pop(key)
        except KeyError:
            continue
        if not isinstance(value, (tuple, list)):
            value = [value] * nchans
        params[key] = value
    return params


# -- sensitive range ----------------------------------------------------------

class RangePlotMixin(object):
    """Mixin that draws range data derived from each input channel

    On :meth:`draw`, the range data are generated for every channel in
    ``self.channels``.  The resulting range-channel names are temporarily
    installed as ``self.channels``, and drawing is delegated to the next
    class in the MRO.
    """
    data = 'spectrogram'
    _threadsafe = False
    defaults = {
        'snr': 8.0,
        'stride': 60.,
        'fftlength': 8,
        'overlap': 4,
        'fmin': 10,
    }

    def __init__(self, *args, **kwargs):
        super(RangePlotMixin, self).__init__(*args, **kwargs)
        nchans = len(self.channels)
        # per-channel FFT parameters, popped out of the plot arguments
        self.fftparams = _get_params(
            ['stride', 'fftlength', 'overlap'], self.pargs, nchans=nchans)
        # per-channel range parameters, popped out of the plot arguments
        self.rangeparams = _get_params(
            ['mass1', 'mass2', 'snr', 'energy', 'fmin', 'fmax', 'range_func'],
            self.pargs, nchans=nchans)
        # any plot type containing 'spec' (range-spectrogram, range-spectrum,
        # cumulative-range-spectrum) uses the spectrogram getter
        self.range_func = (get_range_spectrogram if 'spec' in self.type
                           else get_range)

    def draw(self, *args, **kwargs):
        """Read in all necessary data and generate a figure

        Any positional or keyword arguments (e.g. ``outputfile``) are
        forwarded to the parent class's ``draw`` method.  ``self.channels``
        is always restored afterwards, even if drawing fails.

        Raises
        ------
        KeyError
            if a configured ``range_func`` is not a key of `RAN_DICT`
        """
        keys = []
        # generate data
        for i, channel in enumerate(self.channels):
            fftkwargs = {key: values[i] for key, values in
                         self.fftparams.items() if values[i] is not None}
            rangekwargs = {key: values[i] for key, values in
                           self.rangeparams.items() if values[i] is not None}
            if self.state and not self.all_data:
                valid = self.state.active
            else:
                valid = SegmentList([self.span])
            # replace the range_func name with the matching gwpy.astro method
            if 'range_func' in rangekwargs:
                name = rangekwargs['range_func']
                try:
                    rangekwargs['range_func'] = RAN_DICT[name]
                except KeyError:
                    raise KeyError(
                        "unknown range_func {!r}, choose one of: {}".format(
                            name, ', '.join(sorted(RAN_DICT)))) from None
            rlist = self.range_func(channel, valid, query=self.read,
                                    **fftkwargs, **rangekwargs)
            try:
                keys.append(str(rlist[0].channel))
            except IndexError:
                # nothing returned, so ask for the range-channel name directly
                keys.append(get_range_channel(channel, **rangekwargs))

        # swap in the range channels for drawing, always restoring the
        # original channel list afterwards
        channels = self.channels
        self.channels = keys
        try:
            return super(RangePlotMixin, self).draw(*args, **kwargs)
        finally:
            self.channels = channels
class RangeDataPlot(RangePlotMixin, get_plot('timeseries')):
    """Time-series plot of the range computed from each input channel
    """
    type = 'range'
    # timeseries defaults, overlaid with the range-mixin defaults
    defaults = get_plot('timeseries').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(ylabel='Sensitive distance [Mpc]')


register_plot(RangeDataPlot)
class RangeDataHistogramPlot(RangePlotMixin, get_plot('histogram')):
    """Histogram of the range computed from each input channel
    """
    type = 'range-histogram'
    # histogram defaults, overlaid with the range-mixin defaults
    defaults = get_plot('histogram').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(xlabel='Sensitive distance [Mpc]')


register_plot(RangeDataHistogramPlot)
class RangeCumulativeHistogramPlot(RangePlotMixin, get_plot('histogram')):
    """Cumulative, density-normalised histogram of the computed range
    """
    type = 'range-cumulative-histogram'
    # histogram defaults, overlaid with the range-mixin defaults
    defaults = get_plot('histogram').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(
        xlabel='Angle-averaged range [Mpc]',
        ylabel='Cumulative time duration',
        log=False,
        cumulative=True,
        density=True,
        range=(1, 'max'),
    )


register_plot(RangeCumulativeHistogramPlot)
class RangeSpectrogramDataPlot(RangePlotMixin, get_plot('spectrogram')):
    """Spectrogram plot of the range data computed from each input channel
    """
    type = 'range-spectrogram'
    # spectrogram defaults, overlaid with the range-mixin defaults
    defaults = get_plot('spectrogram').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(cmap='inferno', norm='linear')


register_plot(RangeSpectrogramDataPlot)
class RangeSpectrumDataPlot(RangePlotMixin, get_plot('spectrum')):
    """Spectrum plot of the range data computed from each input channel
    """
    type = 'range-spectrum'
    data = 'spectrum'
    # spectrum defaults, overlaid with the range-mixin defaults
    defaults = get_plot('spectrum').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(xlabel='Frequency [Hz]', yscale='linear')


register_plot(RangeSpectrumDataPlot)
class RangeCumulativeSpectrumDataPlot(RangePlotMixin, get_plot('spectrum')):
    """Spectrum plot of the cumulative fraction of range, in percent
    """
    type = 'cumulative-range-spectrum'
    data = 'spectrum'
    # spectrum defaults, overlaid with the range-mixin defaults
    defaults = get_plot('spectrum').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(
        xlabel='Frequency [Hz]',
        ylim=[0, 100],
        ylabel='Cumulative fraction of range [%]',
        yscale='linear',
        yticks=[0, 20, 40, 60, 80, 100],
        ytickmarks=[0, 20, 40, 60, 80, 100],
    )


register_plot(RangeCumulativeSpectrumDataPlot)


# -- time-volume --------------------------------------------------------------
class SimpleTimeVolumeDataPlot(get_plot('segments')):
    """Time-series of the time-volume searched by an interferometer

    Sources are given as alternating ``channel, flag`` pairs.  Each range
    channel is paired with the flag whose active segments define when that
    interferometer contributes time-volume.
    """
    data = 'timeseries'
    type = 'time-volume'
    DRAW_PARAMS = get_plot('timeseries').DRAW_PARAMS
    defaults = get_plot('timeseries').defaults.copy()
    parse_plot_kwargs = get_plot('timeseries').parse_plot_kwargs

    def __init__(self, sources, *args, **kwargs):
        if isinstance(sources, str):
            sources = split_channels(sources)
        # sources alternate channel, flag, channel, flag, ...
        channels = sources[::2]
        flags = sources[1::2]
        # initialise via the timeseries plot constructor (not the segments
        # plot constructor) using only the channels
        get_plot('timeseries').__init__(self, channels, *args, **kwargs)
        self._allflags = []
        self.flags = flags

    @classmethod
    def from_ini(cls, *args, **kwargs):
        """Build a new plot using the timeseries plot's ``from_ini`` logic
        """
        from_ini = get_plot('timeseries').from_ini
        # NOTE(review): if ``from_ini`` is a classmethod it comes back bound
        # to the timeseries plot class, and passing ``cls`` positionally
        # would shift every argument by one; unwrap it so ``cls`` binds as
        # the first parameter.  A plain function has no ``__func__`` and is
        # called exactly as before.
        from_ini = getattr(from_ini, '__func__', from_ini)
        return from_ini(cls, *args, **kwargs)

    @property
    def pid(self):
        """Identifier for this plot, built by hashing its channels and flags
        """
        try:
            return self._pid
        except AttributeError:
            self._pid = hash("".join(map(str, self.channels+self.flags)))
            return self._pid

    @pid.setter
    def pid(self, id_):
        self._pid = str(id_)

    @pid.deleter
    def pid(self):
        del self._pid

    @staticmethod
    def calculate_time_volume(segments, range):
        """Calculate the time-volume searched in each sample of ``range``

        Parameters
        ----------
        segments : `~gwpy.segments.SegmentList`
            segments during which the interferometer was searching
        range : `~gwpy.timeseries.TimeSeries`
            sensitive-distance time-series; NB: its unit is overridden
            in place to ``Mpc``

        Returns
        -------
        timevolume : `~gwpy.timeseries.TimeSeries`
            ``4/3 * pi * range**3 * livetime`` per sample, in ``Mpc^3 kyr``
        """
        try:
            ts = TimeSeries(numpy.zeros(range.size), xindex=range.times,
                            unit='s')
        except IndexError:
            ts = TimeSeries(numpy.zeros(range.size), unit='s',
                            x0=range.x0, dx=range.dx)
        dx = range.dx.value
        # override range units
        range.override_unit('Mpc')
        # use float, not LIGOTimeGPS for speed
        segments = type(segments)([type(s)(float(s[0]), float(s[1]))
                                   for s in segments])

        def livetime_(t):
            # seconds of the sample [t, t + dx) covered by ``segments``
            return float(abs(SegmentList([Segment(t, t+dx)]) & segments))

        livetime = numpy.vectorize(livetime_, otypes=[float])
        ts[:] = livetime(ts.times.value) * ts.unit
        return (4/3. * pi * ts * range ** 3).to('Mpc^3 kyr')

    def combined_time_volume(self, allsegments, allranges):
        """Estimate the time-volume searched by the whole network

        At each sample, the network range is the smaller of the two largest
        per-interferometer ranges.  Time-volume is only accumulated over
        segments in which at least two interferometers were active.

        The input lists are not modified.

        Returns
        -------
        timevolume : `~gwpy.timeseries.TimeSeries` or `None`
            combined time-volume, or `None` if no entry of ``allranges``
            contains any data
        """
        # first drop any IFOs that have no range data at all
        keep = [i for i, r in enumerate(allranges) if len(r.value)]
        segments = [allsegments[i] for i in keep]
        ranges = [allranges[i] for i in keep]
        if not ranges:
            return None

        # find the earliest time we have any data
        min_x0 = min(r.x0.value for r in ranges)
        for i, r in enumerate(ranges):
            # pad all range time series that don't start at min_x0 so that
            # all time series have the same start time; round (rather than
            # truncate) so floating-point error cannot drop a sample
            if r.x0.value > min_x0:
                missing = int(round((r.x0.value - min_x0) / r.dx.value))
                ranges[i] = r.pad((missing, 0))

        first = ranges[0]
        try:
            combined_range = TimeSeries(numpy.zeros(first.size),
                                        xindex=first.times, unit='Mpc')
        except IndexError:
            combined_range = TimeSeries(numpy.zeros(first.size), unit='Mpc',
                                        x0=first.x0, dx=first.dx)

        # get coincident observing segments (any pair of IFOs active)
        coincident = SegmentList()
        for seg_a, seg_b in combinations(segments, 2):
            coincident.extend(seg_a & seg_b)
        coincident = coincident.coalesce()

        # get effective network range
        values = [min(nlargest(2, x))
                  for x in zip(*[r.value for r in ranges])]
        size = min(r.size for r in ranges)
        combined_range[:size] = values * combined_range.unit

        # compute time-volume
        return self.calculate_time_volume(coincident, combined_range)

    def draw(self, outputfile=None):
        """Generate the figure for this plot
        """
        plot = self.init_plot()
        ax = plot.axes[0]

        # get plotting arguments
        cumulative = self.pargs.pop('cumulative', False)
        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # set ylabel
        if cumulative:
            self.pargs.setdefault(
                'ylabel', 'Cumulative time-volume [Mpc$^3$ kyr]')
        else:
            self.pargs.setdefault('ylabel', 'Time-volume [Mpc$^3$ kyr]')

        def _plot(series, **kwargs):
            # draw the per-sample or running-total time-volume
            ax.plot(series.cumsum() if cumulative else series, **kwargs)

        # get data
        allsegs, allranges = [], []
        for channel, flag, pargs in zip(self.channels, self.flags, plotargs):
            pad = 0
            if self.state and not self.all_data:
                valid = self.state.active
                pad = numpy.nan
            elif (channel.sample_rate is not None and
                    channel.sample_rate.value):
                # extend the span by one sample on each side
                valid = SegmentList([self.span.protract(
                    1/channel.sample_rate.value)])
            else:
                valid = SegmentList([self.span])
            data = get_timeseries(
                channel, valid, query=False).join(gap='pad', pad=pad)
            if not data.unit or data.unit.to_string() in ['', 'undef']:
                data.override_unit('Mpc')
            segments = get_segments(flag, valid, query=False)
            timevolume = self.calculate_time_volume(segments.active, data)
            _plot(timevolume, **pargs)
            allsegs.append(segments.active)
            allranges.append(data)

        # estimate combined time-volume
        if self.all_data and len(self.channels) > 1:
            combined_timevolume = self.combined_time_volume(
                allsegs, allranges)
            if combined_timevolume is not None:
                # copy, so the last channel's plot arguments stay untouched
                pargs = dict(plotargs[-1])
                pargs['color'] = '#000000'
                pargs['label'] = 'Combined'
                pargs['linestyle'] = '--'
                _plot(combined_timevolume, **pargs)

        # add horizontal lines; accept a single value as well as a list
        hlines = self.pargs.get('hline', [])
        if isinstance(hlines, (str, int, float)):
            hlines = [hlines]
        for yval in hlines:
            try:
                yval = float(yval)
            except ValueError:
                continue
            ax.plot([self.start, self.end], [yval, yval],
                    linestyle='--', color='red')

        # customise plot
        self.apply_parameters(ax, **self.pargs)
        if (len(self.channels) > 1 or plotargs[0].get('label', None) in
                [texify(str(self.channels[0])), None]):
            ax.legend(**legendargs)

        # add extra axes and finalise
        self.add_state_segments(ax)
        self.add_future_shade()
        if ax.get_yscale() == 'log':
            ax.yaxis.set_major_locator(LogLocator())
        else:
            ax.yaxis.set_major_locator(MaxNLocator(8))
        ticks = ax.get_yticks()
        ax.yaxis.set_ticklabels(ticks)
        return self.finalize(outputfile=outputfile)


register_plot(SimpleTimeVolumeDataPlot)
class GWpyTimeVolumeDataPlot(RangePlotMixin, SimpleTimeVolumeDataPlot):
    """Time-volume plot whose range is calculated on-the-fly

    `RangePlotMixin` computes range data for each input channel and then
    hands the resulting range channels to `SimpleTimeVolumeDataPlot` for
    drawing.
    """
    _threadsafe = False
    type = 'strain-time-volume'


register_plot(GWpyTimeVolumeDataPlot)