# -*- coding: utf-8 -*-
# Copyright (C) Duncan Macleod (2013)
#
# This file is part of GWSumm.
#
# GWSumm is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWSumm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWSumm. If not, see <http://www.gnu.org/licenses/>.
"""Definitions for range plots
"""
from math import pi
from heapq import nlargest
from itertools import combinations
import numpy
from matplotlib.ticker import (LogLocator, MaxNLocator)
import gwpy.astro
from gwpy.segments import (Segment, SegmentList)
from gwpy.timeseries import TimeSeries
from gwdetchar.plot import texify
from .registry import (get_plot, register_plot)
from .utils import hash
from ..data import (get_range_channel, get_range,
get_range_spectrogram, get_timeseries)
from ..segments import get_segments
from ..channels import split as split_channels
__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'
__credits__ = 'Alex Urban <alexander.urban@ligo.org>'

# -- utils --------------------------------------------------------------------

# Lookup table translating the ``range_func`` plot argument (a name) into
# the matching :mod:`gwpy.astro` callable.
# Note that the ``burst_range_psd`` entry maps to gwpy's
# ``burst_range_spectrum`` (the names differ by design of this table).
RAN_DICT = dict(
    sensemon_range=gwpy.astro.sensemon_range,
    sensemon_range_psd=gwpy.astro.sensemon_range_psd,
    inspiral_range=gwpy.astro.inspiral_range,
    inspiral_range_psd=gwpy.astro.inspiral_range_psd,
    burst_range=gwpy.astro.burst_range,
    burst_range_psd=gwpy.astro.burst_range_spectrum,
)
def _get_params(keys, pargs, nchans=1):
"""Return a `dict` of `list` of plot arguments for every channel
"""
params = {}
for key in keys:
try:
value = pargs.pop(key)
except KeyError:
continue
if not isinstance(value, (tuple, list)):
value = [value] * nchans
params[key] = value
return params
# -- sensitive range ----------------------------------------------------------
class RangePlotMixin(object):
    """Mixin that turns input channels into sensitive-range data before
    delegating to the parent plot class's ``draw``

    At construction, per-channel FFT parameters (``stride``, ``fftlength``,
    ``overlap``) and range parameters (``mass1``, ``mass2``, ``snr``,
    ``energy``, ``fmin``, ``fmax``, ``range_func``) are popped from
    ``self.pargs``; each may be a scalar (applied to every channel) or a
    per-channel list/tuple.
    """
    data = 'spectrogram'
    _threadsafe = False
    defaults = {
        'snr': 8.0,
        'stride': 60.,
        'fftlength': 8,
        'overlap': 4,
        'fmin': 10,
    }

    def __init__(self, *args, **kwargs):
        super(RangePlotMixin, self).__init__(*args, **kwargs)
        nchans = len(self.channels)
        self.fftparams = _get_params(
            ['stride', 'fftlength', 'overlap'], self.pargs, nchans=nchans)
        self.rangeparams = _get_params(
            ['mass1', 'mass2', 'snr', 'energy', 'fmin', 'fmax', 'range_func'],
            self.pargs, nchans=nchans)
        # NOTE: ``self.range_func`` is the *data getter* used in ``draw``;
        # it is distinct from the ``range_func`` plot argument stored in
        # ``self.rangeparams``, which is forwarded to that getter
        self.range_func = (get_range_spectrogram if
                           'spec' in self.type else get_range)

    def draw(self, *args, **kwargs):
        """Read in all necessary data and generate a figure

        Range data are generated for every channel via ``self.range_func``,
        then ``self.channels`` is temporarily replaced by the names of the
        resulting range channels while the parent ``draw`` runs. Any
        arguments are forwarded to the parent ``draw``.
        """
        # the span of data to process is the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])

        keys = []
        for i, channel in enumerate(self.channels):
            fftkwargs = {key: values[i] for key, values in
                         self.fftparams.items() if values[i] is not None}
            rangekwargs = {key: values[i] for key, values in
                           self.rangeparams.items() if values[i] is not None}
            # translate a named range function into the gwpy.astro callable;
            # an already-callable value is passed through unchanged
            func = rangekwargs.get('range_func')
            if func is not None and not callable(func):
                rangekwargs['range_func'] = RAN_DICT[func]
            rlist = self.range_func(channel, valid, query=self.read,
                                    **fftkwargs, **rangekwargs)
            try:
                keys.append(str(rlist[0].channel))
            except IndexError:
                # no data returned: fall back to the expected channel name
                keys.append(get_range_channel(channel, **rangekwargs))

        # swap in the range channels for the parent draw, always restoring
        # the original channel list even if drawing fails
        channels = self.channels
        self.channels = keys
        try:
            return super(RangePlotMixin, self).draw(*args, **kwargs)
        finally:
            self.channels = channels
class RangeDataPlot(RangePlotMixin, get_plot('timeseries')):
    """Time-series of sensitive distance, generated on-the-fly from the
    input channels
    """
    type = 'range'

    defaults = get_plot('timeseries').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults['ylabel'] = 'Sensitive distance [Mpc]'


register_plot(RangeDataPlot)
class RangeDataHistogramPlot(RangePlotMixin, get_plot('histogram')):
    """Histogram of sensitive distance, generated on-the-fly from the
    input channels
    """
    type = 'range-histogram'

    defaults = get_plot('histogram').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults['xlabel'] = 'Sensitive distance [Mpc]'


register_plot(RangeDataHistogramPlot)
class RangeCumulativeHistogramPlot(RangePlotMixin, get_plot('histogram')):
    """Normalised cumulative histogram of range, generated on-the-fly from
    the input channels
    """
    type = 'range-cumulative-histogram'

    defaults = get_plot('histogram').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(
        xlabel='Angle-averaged range [Mpc]',
        ylabel='Cumulative time duration',
        log=False,
        cumulative=True,
        density=True,
        range=(1, 'max'),
    )


register_plot(RangeCumulativeHistogramPlot)
class RangeSpectrogramDataPlot(RangePlotMixin, get_plot('spectrogram')):
    """Spectrogram of range, generated on-the-fly from the input channels
    """
    type = 'range-spectrogram'

    defaults = get_plot('spectrogram').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(cmap='inferno', norm='linear')


register_plot(RangeSpectrogramDataPlot)
class RangeSpectrumDataPlot(RangePlotMixin, get_plot('spectrum')):
    """Range as a function of frequency, generated on-the-fly from the
    input channels
    """
    type = 'range-spectrum'
    data = 'spectrum'

    defaults = get_plot('spectrum').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(xlabel='Frequency [Hz]', yscale='linear')


register_plot(RangeSpectrumDataPlot)
class RangeCumulativeSpectrumDataPlot(RangePlotMixin, get_plot('spectrum')):
    """Cumulative fraction of range (in percent) as a function of
    frequency, generated on-the-fly from the input channels
    """
    type = 'cumulative-range-spectrum'
    data = 'spectrum'

    defaults = get_plot('spectrum').defaults.copy()
    defaults.update(RangePlotMixin.defaults)
    defaults.update(
        xlabel='Frequency [Hz]',
        ylim=[0, 100],
        ylabel='Cumulative fraction of range [%]',
        yscale='linear',
        yticks=[0, 20, 40, 60, 80, 100],
        ytickmarks=[0, 20, 40, 60, 80, 100],
    )


register_plot(RangeCumulativeSpectrumDataPlot)
# -- time-volume --------------------------------------------------------------
class SimpleTimeVolumeDataPlot(get_plot('segments')):
    """Time-series of the time-volume searched by an interferometer

    ``sources`` alternates range channels and data-quality flags
    (``channel1, flag1, channel2, flag2, ...``); each range channel is
    integrated over the active segments of the flag that follows it.
    """
    data = 'timeseries'
    type = 'time-volume'
    # borrow drawing parameters, defaults and kwarg parsing from the
    # time-series plot class
    DRAW_PARAMS = get_plot('timeseries').DRAW_PARAMS
    defaults = get_plot('timeseries').defaults.copy()
    parse_plot_kwargs = get_plot('timeseries').parse_plot_kwargs

    def __init__(self, sources, *args, **kwargs):
        if isinstance(sources, str):
            sources = split_channels(sources)
        # even-indexed sources are range channels, odd-indexed are flags
        channels = sources[::2]
        flags = sources[1::2]
        get_plot('timeseries').__init__(self, channels, *args, **kwargs)
        self._allflags = []
        self.flags = flags

    @classmethod
    def from_ini(cls, *args, **kwargs):
        # NOTE(review): ``get_plot('timeseries').from_ini`` is looked up on
        # the class; if it is itself a classmethod, ``cls`` is passed as an
        # extra positional argument here -- confirm against the base class
        return get_plot('timeseries').from_ini(cls, *args, **kwargs)

    @property
    def pid(self):
        """Identifier for this plot, hashed from its channels and flags
        (computed lazily on first access)
        """
        try:
            return self._pid
        except AttributeError:
            self._pid = hash("".join(map(str, self.channels+self.flags)))
            return self._pid

    @pid.setter
    def pid(self, id_):
        self._pid = str(id_)

    @pid.deleter
    def pid(self):
        del self._pid

    @staticmethod
    def calculate_time_volume(segments, range):
        """Calculate the time-volume accumulated in each sample of ``range``

        Parameters
        ----------
        segments : `~gwpy.segments.SegmentList`
            segments of live time to integrate over
        range : `~gwpy.timeseries.TimeSeries`
            sensitive range; NOTE: its unit is overridden to ``Mpc``
            in place

        Returns
        -------
        timevolume : `~gwpy.timeseries.TimeSeries`
            ``4/3 * pi * range**3`` times the live time within each sample,
            in ``Mpc^3 kyr``
        """
        try:
            ts = TimeSeries(numpy.zeros(range.size), xindex=range.times,
                            unit='s')
        except IndexError:
            ts = TimeSeries(numpy.zeros(range.size), unit='s',
                            x0=range.x0, dx=range.dx)
        dx = range.dx.value

        # override range units
        range.override_unit('Mpc')

        # use float, not LIGOTimeGPS for speed
        segments = type(segments)([type(s)(float(s[0]), float(s[1])) for
                                   s in segments])

        def livetime_(t):
            # live time inside the sample [t, t + dx)
            return float(abs(SegmentList([Segment(t, t+dx)]) & segments))

        livetime = numpy.vectorize(livetime_, otypes=[float])
        ts[:] = livetime(ts.times.value) * ts.unit
        return (4/3. * pi * ts * range ** 3).to('Mpc^3 kyr')

    def combined_time_volume(self, allsegments, allranges):
        """Estimate the time-volume of the combined detector network

        At each sample the network range is the second-largest individual
        range (or the only one, if a single detector remains), and live
        time is only counted when at least two detectors' segments overlap.

        Parameters
        ----------
        allsegments : `list` of `~gwpy.segments.SegmentList`
            per-detector segments; modified in place
        allranges : `list` of `~gwpy.timeseries.TimeSeries`
            per-detector ranges, in the same order as ``allsegments``;
            modified in place

        Returns
        -------
        timevolume : `~gwpy.timeseries.TimeSeries` or `None`
            the combined time-volume, or `None` if no detector has any
            range data
        """
        # first remove any IFOs that have no range data at all
        empty = [i for i, r in enumerate(allranges) if not len(r.value)]
        for i in reversed(empty):
            allsegments.pop(i)
            allranges.pop(i)
        if not allranges:
            # nothing to combine (previously crashed in min() below)
            return None

        # find the earliest time we have any data
        min_x0 = min(r.x0.value for r in allranges)
        for i, r in enumerate(allranges):
            # pad all range time series that don't start at min_x0
            # so that all time series have the same start time
            if r.x0.value > min_x0:
                missing = int((r.x0.value - min_x0) / r.dx.value)
                allranges[i] = r.pad((missing, 0))
        try:
            combined_range = TimeSeries(numpy.zeros(allranges[0].size),
                                        xindex=allranges[0].times,
                                        unit='Mpc')
        except IndexError:
            combined_range = TimeSeries(
                numpy.zeros(allranges[0].size), unit='Mpc',
                x0=allranges[0].x0, dx=allranges[0].dx)

        # get coincident observing segments: the union of all pairwise
        # intersections, i.e. times when at least two IFOs are active
        coincident = SegmentList()
        for first, second in combinations(allsegments, 2):
            coincident.extend(first & second)
        coincident = coincident.coalesce()

        # get effective network range (zip truncates to the shortest series)
        values = [min(nlargest(2, x)) for x in
                  zip(*(r.value for r in allranges))]
        size = min(r.size for r in allranges)
        combined_range[:size] = values * combined_range.unit

        # compute time-volume
        return self.calculate_time_volume(coincident, combined_range)

    def draw(self, outputfile=None):
        """Generate the figure for this plot
        """
        plot = self.init_plot()
        ax = plot.axes[0]

        # get plotting arguments
        cumulative = self.pargs.pop('cumulative', False)
        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # set ylabel
        if cumulative:
            self.pargs.setdefault(
                'ylabel',
                'Cumulative time-volume [Mpc$^3$ kyr]',
            )
        else:
            self.pargs.setdefault(
                'ylabel',
                'Time-volume [Mpc$^3$ kyr]',
            )

        # get data
        allsegs, allranges = ([], [])
        for channel, flag, pargs in zip(self.channels, self.flags, plotargs):
            pad = 0
            # a channel's sample rate may be unset (None), so don't
            # assume it has a ``.value``
            rate = getattr(channel.sample_rate, 'value', None)
            if self.state and not self.all_data:
                valid = self.state.active
                pad = numpy.nan
            elif rate:
                valid = SegmentList([self.span.protract(1/rate)])
            else:
                valid = SegmentList([self.span])
            data = get_timeseries(
                channel, valid, query=False).join(gap='pad', pad=pad)
            if not data.unit or data.unit.to_string() in ['', 'undef']:
                data.override_unit('Mpc')
            segments = get_segments(flag, valid, query=False)
            timevolume = self.calculate_time_volume(segments.active, data)
            ax.plot(timevolume.cumsum() if cumulative else timevolume,
                    **pargs)
            allsegs.append(segments.active)
            allranges.append(data)

        # estimate combined time-volume
        if self.all_data and len(self.channels) > 1:
            # copy, so the last channel's plot arguments are left untouched
            pargs = dict(plotargs[-1])
            pargs['color'] = '#000000'
            pargs['label'] = 'Combined'
            pargs['linestyle'] = '--'
            combined_timevolume = self.combined_time_volume(
                allsegs, allranges)
            if combined_timevolume is not None:
                ax.plot(combined_timevolume.cumsum() if cumulative
                        else combined_timevolume, **pargs)

        # add horizontal lines to add
        for yval in self.pargs.get('hline', []):
            try:
                yval = float(yval)
            except ValueError:
                continue
            ax.plot([self.start, self.end], [yval, yval],
                    linestyle='--', color='red')

        # customise plot
        self.apply_parameters(ax, **self.pargs)
        if (len(self.channels) > 1 or plotargs[0].get('label', None) in
                [texify(str(self.channels[0])), None]):
            ax.legend(**legendargs)

        # add extra axes and finalise
        self.add_state_segments(ax)
        self.add_future_shade()
        if ax.get_yscale() == 'log':
            ax.yaxis.set_major_locator(LogLocator())
        else:
            ax.yaxis.set_major_locator(MaxNLocator(8))
        ticks = ax.get_yticks()
        ax.yaxis.set_ticklabels(ticks)
        return self.finalize(outputfile=outputfile)


register_plot(SimpleTimeVolumeDataPlot)
class GWpyTimeVolumeDataPlot(RangePlotMixin, SimpleTimeVolumeDataPlot):
    """TimeVolumeDataPlot where the range is calculated on-the-fly

    `RangePlotMixin` generates the range data for each channel, then
    `SimpleTimeVolumeDataPlot` draws the resulting time-volume.
    """
    _threadsafe = False
    type = 'strain-time-volume'


register_plot(GWpyTimeVolumeDataPlot)