# -*- coding: utf-8 -*-
# Copyright (C) Duncan Macleod (2013)
#
# This file is part of GWSumm.
#
# GWSumm is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GWSumm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GWSumm. If not, see <http://www.gnu.org/licenses/>.
"""Definitions for event trigger plots
"""
import re
from collections import OrderedDict
from itertools import cycle
from numpy import isinf
from astropy.units import Quantity
from gwpy.detector import (Channel, ChannelList)
from gwpy.plot.gps import GPSTransform
from gwpy.plot.utils import (color_cycle, marker_cycle)
from gwpy.segments import SegmentList
from gwdetchar.plot import texify
from ... import globalv
from ...channels import get_channel
from ...data import (get_timeseries, add_timeseries)
from ..registry import (get_plot, register_plot)
from ...triggers import (get_triggers, get_time_column)
from ...utils import re_cchar
from ..utils import (get_column_string, hash)
__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

# registered 'timeseries' plot class, used as the base for the
# time-axis trigger plots defined below
TimeSeriesDataPlot = get_plot("timeseries")
# LaTeX math-mode commands for comparison operators, used to build legend
# labels for binned trigger-rate curves; operators not listed here are
# rendered with their plain string form
LATEX_OPERATOR = dict((
    ('>=', r'\geq'),
    ('<=', r'\leq'),
))
class TriggerPlotMixin(object):
    """Mixin to override `allchannels` and `pid` for trigger plots

    We don't need to get channel data for trigger plots, so the list of
    channels to load is reduced to the bare channel names, and the plot
    identifier is built from the channel names, any per-channel filters,
    and the optional ``filterstr`` keyword.
    """
    def __init__(self, *args, **kwargs):
        # optional filter expression applied to every trigger table
        # (via ``table.filter(filterstr)``) by the subclasses
        self.filterstr = kwargs.pop('filterstr', None)
        super(TriggerPlotMixin, self).__init__(*args, **kwargs)

    @property
    def allchannels(self):
        """List of all unique channels for this plot

        Anything after a ``#`` or ``@`` in a channel name is discarded;
        duplicates are removed keeping the order of first appearance so
        the result is deterministic.
        """
        # maxsplit must be a keyword: positional use is deprecated (py3.13)
        names = OrderedDict.fromkeys(
            re.split(r'[#@]', str(c), maxsplit=1)[0] for c in self._channels)
        return ChannelList(map(Channel, names))

    @property
    def pid(self):
        """Unique identifier, hashed from channel names and filters.

        The value is computed once and cached in ``self._pid``.
        """
        try:
            return self._pid
        except AttributeError:
            chans = "".join(map(str, self.channels))
            # per-channel filter (or frequency response) if the channel
            # object carries one, else an empty string
            filts = "".join(
                str(getattr(c, 'filter',
                            getattr(c, 'frequency_response', '')))
                for c in self.channels)
            if self.filterstr:
                filts += self.filterstr
            self._pid = hash(chans + filts)
            return self.pid
class TriggerDataPlot(TriggerPlotMixin, TimeSeriesDataPlot):
    """Standard event trigger plot

    For each channel the triggers are drawn as a scatter of column ``x``
    against column ``y``, optionally coloured by column ``color``; the
    three column names come from the plot keyword arguments.
    """
    type = 'triggers'
    data = 'triggers'
    defaults = TimeSeriesDataPlot.defaults.copy()
    defaults.update({
        'x': 'time',
        'y': 'snr',
        'color': None,
        'edgecolor': 'face',
        'facecolor': None,
        'marker': 'o',
        's': 20,
        'vmin': None,
        'vmax': None,
        'clim': None,
        'cmap': 'YlGnBu',
        'logcolor': False,
        'colorlabel': None,
    })

    def __init__(self, channels, start, end, state=None, outdir='.',
                 etg=None, **kwargs):
        # single-channel plots are titled '<channel> (<etg>)' by default
        if (len(channels) == 1) and ('title' not in kwargs):
            kwargs['title'] = texify('%s (%s)' % (str(channels[0]), etg))
        super(TriggerDataPlot, self).__init__(channels, start, end,
                                              state=state, outdir=outdir,
                                              **kwargs)
        self.etg = etg
        # [x, y, colour] column names; popped from the plot arguments so
        # they are not applied to the axes later
        self.columns = [self.pargs.pop(c) for c in ('x', 'y', 'color')]

    @property
    def pid(self):
        """Unique identifier for this `TriggerDataPlot`.

        Extends the standard trigger-plot pid with the ETG
        and each of the column names.
        """
        try:
            return self._pid
        except AttributeError:
            # populates self._pid with the channel/filter hash
            super(TriggerDataPlot, self).pid
            # str() so the default etg=None does not crash re.sub
            self._pid += '_%s' % re_cchar.sub('_', str(self.etg))
            for column in self.columns:
                if column:
                    self._pid += '_%s' % re_cchar.sub('_', str(column))
            self._pid = self._pid.upper()
            return self.pid

    @pid.setter
    def pid(self, id_):
        self._pid = str(id_)

    def draw(self):
        """Read the triggers for each channel and draw the scatter plot.

        Returns the output of `finalize()`.
        """
        xcolumn, ycolumn, ccolumn = self.columns
        nchan = len(self.channels)

        # initialise figure
        plot = self.init_plot()
        ax = plot.gca()
        ax.grid(visible=True, which='both')

        # work out labels
        labels = self.pargs.pop('labels', self.channels)
        if isinstance(labels, str):
            labels = labels.split(',')
        labels = [str(s).strip('\n ') for s in labels]

        # get colouring and loudest-event params
        cmap = self.pargs.pop('cmap')
        clim = self.pargs.pop('clim', self.pargs.pop('colorlim', None))
        cnorm = 'log' if self.pargs.pop('logcolor', False) else None
        clabel = self.pargs.pop('colorlabel', None)
        no_loudest = self.pargs.pop('no-loudest', False) is not False
        loudest_by = self.pargs.pop('loudest-by', None)

        # distribute scatter arguments over the channels: sequences are
        # cycled, scalars repeated; unset facecolor/marker on multi-channel
        # plots cycle through the gwpy defaults
        plotargs = [dict() for _ in range(nchan)]
        for key in ['vmin', 'vmax', 'edgecolor', 'facecolor', 'cmap', 's',
                    'marker', 'rasterized', 'sortbycolor']:
            try:
                val = self.pargs.pop(key)
            except KeyError:
                continue
            if key == 'facecolor' and nchan > 1 and val is None:
                val = color_cycle()
            elif key == 'marker' and nchan > 1 and val is None:
                val = marker_cycle()
            elif isinstance(val, (list, tuple, cycle)):
                val = cycle(val)
            else:
                val = cycle([val] * nchan)
            for pargs in plotargs:
                pargs[key] = next(val)

        # add data
        valid = SegmentList([self.span])
        if self.state and not self.all_data:
            valid &= self.state.active
        ntrigs = 0
        table = None  # last table read, used for the loudest-event marker
        for channel, label, pargs in zip(self.channels, labels, plotargs):
            try:
                channel = get_channel(channel)
            except ValueError:
                pass
            if '#' in str(channel) or '@' in str(channel):
                key = '%s,%s' % (str(channel),
                                 self.state and str(self.state) or 'All')
            else:
                key = str(channel)
            table = get_triggers(key, self.etg, valid, query=False)
            if self.filterstr is not None:
                table = table.filter(self.filterstr)
            ntrigs += len(table)
            # access channel parameters for limits
            for c, column in zip(('x', 'y', 'c'),
                                 (xcolumn, ycolumn, ccolumn)):
                if not column:
                    continue
                # hack for SnglBurst frequency nonsense
                if column in ['peak_frequency', 'central_freq']:
                    column = 'frequency'
                param = '%s_range' % column
                lim = '%slim' % c
                # x and y limits go into the axes parameters
                if (getattr(channel, param, None) is not None
                        and c in ('x', 'y')):
                    self.pargs.setdefault(lim, getattr(channel, param))
                    if isinstance(self.pargs[lim], Quantity):
                        self.pargs[lim] = self.pargs[lim].value
                # set clim separately
                elif hasattr(channel, param):
                    if not clim:
                        clim = getattr(channel, param)
            ax.scatter(table[xcolumn], table[ycolumn],
                       c=table[ccolumn] if ccolumn else None,
                       label=label, **pargs)

        # customise plot
        legendargs = self.parse_legend_kwargs(markerscale=3)
        if nchan == 1:
            self.pargs.setdefault('title', texify(
                '%s (%s)' % (str(self.channels[0]), self.etg)))
        for axis in ('x', 'y'):  # prevent zeros on log scale
            scale = getattr(ax, 'get_{0}scale'.format(axis))()
            lim = getattr(ax, 'get_{0}lim'.format(axis))()
            if scale == 'log' and lim[0] <= 0 and not ntrigs:
                getattr(ax, 'set_{0}lim'.format(axis))(1, 10)
        self.apply_parameters(ax, **self.pargs)

        # correct log-scale empty axes
        if any(map(isinf, ax.get_ylim())):
            ax.set_ylim(0.1, 10)

        # add colorbar
        if ccolumn:
            if not ntrigs:
                # invisible point so the colorbar has a mappable to use
                ax.scatter([1], [1], c=[1], visible=False)
            ax.colorbar(cmap=cmap, clim=clim, norm=cnorm, label=clabel)

        # mark the loudest event (single-channel plots only)
        if (nchan == 1 and table is not None and len(table)
                and not no_loudest):
            columns = [x for x in
                       (loudest_by or ccolumn or ycolumn, xcolumn, ycolumn,
                        ccolumn) if x is not None]
            self.add_loudest_event(ax, table, *columns, fontsize='large')

        if nchan > 1:
            ax.legend(**legendargs)

        # add state segments
        if isinstance(ax.xaxis.get_transform(), GPSTransform):
            self.add_state_segments(ax)
            self.add_future_shade()

        # finalise
        return self.finalize()

    def add_loudest_event(self, ax, table, rank, *columns, **kwargs):
        """Mark the row of ``table`` with the largest ``rank`` value.

        The first two ``columns`` give the (x, y) position of the star
        marker; a text summary of ``rank`` and all ``columns`` is written
        above the axes (or at ``position``, in axes coordinates).
        Remaining ``kwargs`` are passed to ``ax.text``.

        Returns the ``(scatter, text)`` artists.
        """
        # get loudest row
        idx = table[rank].argmax()
        row = table[idx]
        x = float(row[columns[0]])
        y = float(row[columns[1]])

        # clip loudest event to axes limits
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        x1 = max(min(x, xlim[1]), xlim[0])
        y1 = max(min(y, ylim[1]), ylim[0])
        if x1 != x or y1 != y:  # loudest event is out of view
            facecolor = 'pink'
            clipon = False
        else:
            facecolor = 'gold'
            clipon = True

        # mark loudest row with star
        coll = ax.scatter([x1], [y1], marker='*', zorder=1000,
                          facecolor=facecolor, edgecolor='black',
                          s=200, clip_on=clipon)

        # get text
        txt = []
        for col in OrderedDict.fromkeys((rank,) + columns):  # unique ordered
            colstr = get_column_string(col)
            # drop trailing zeros, then a bare trailing point; note that
            # rstrip('.0') would strip a *set* of chars (100.00 -> '1')
            try:
                valstr = f"{row[col]:.2f}".rstrip('0').rstrip('.')
            except (TypeError, ValueError):  # not float-formattable
                valstr = str(row[col])
            txt.append(f'{colstr} = {valstr}')

        # get position for new text
        try:
            pos = kwargs.pop('position')
        except KeyError:  # user didn't specify, set default above title
            tpos = ax.title.get_position()
            pos = [tpos[0], tpos[1]+0.09]

        # parse text kwargs
        text_kw = {  # defaults
            'transform': ax.transAxes,
            'verticalalignment': 'bottom',
            'horizontalalignment': 'center',
        }
        text_kw.update(kwargs)
        if 'ha' in text_kw:  # handle short versions of alignment params
            text_kw['horizontalalignment'] = text_kw.pop('ha')
        if 'va' in text_kw:
            text_kw['verticalalignment'] = text_kw.pop('va')

        # add text
        text = ax.text(pos[0], pos[1],
                       f"Loudest event: {', '.join(txt)}",
                       **text_kw)
        return coll, text


register_plot(TriggerDataPlot)
class TriggerTimeSeriesDataPlot(TimeSeriesDataPlot):
    """Custom time-series plot to handle discontiguous `TimeSeries`.

    Each channel's data is a list of segments, all drawn in the same
    colour with a single legend entry.
    """
    type = 'trigger-timeseries'
    data = 'triggers'

    def draw(self):
        """Read in all necessary data, and generate the figure.
        """
        plot = self.init_plot()
        ax = plot.gca()

        # work out labels
        labels = self.pargs.pop('labels', self.channels)
        if isinstance(labels, str):
            labels = labels.split(',')
        labels = [str(s).strip('\n ') for s in labels]

        # the span to read is the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])

        # add data
        for label, channel in zip(labels, self.channels):
            label = texify(label)
            data = get_timeseries(channel, valid, query=False)
            # handle no timeseries: keep an (invisible) legend entry
            if not len(data):
                ax.plot([0], [0], visible=False, label=label)
                continue
            # plot time-series, all segments in the first segment's colour
            color = None
            for ts in data:
                # double-check log scales: replace zeros on a copy so the
                # series returned by get_timeseries is not modified
                if self.logy:
                    zeros = ts.value == 0
                    if zeros.any():
                        ts = ts.copy()
                        ts.value[zeros] = 1e-100
                if color is None:
                    line = ax.plot(ts, label=label)[0]
                    color = line.get_color()
                else:
                    ax.plot(ts, color=color, label=None)
            # allow channel data to set parameters
            if hasattr(data[0].channel, 'amplitude_range'):
                self.pargs.setdefault('ylim',
                                      data[0].channel.amplitude_range)

        # add horizontal lines (non-numeric entries are skipped)
        for yval in self.pargs.get('hline', []):
            try:
                yval = float(yval)
            except ValueError:
                continue
            ax.plot([self.start, self.end], [yval, yval],
                    linestyle='--', color='red')

        # customise plot
        legendargs = self.parse_legend_kwargs()
        self.apply_parameters(ax, **self.pargs)
        if len(self.channels) > 1:
            ax.legend(**legendargs)

        # finalise
        self.add_state_segments(ax)
        return self.finalize()


register_plot(TriggerTimeSeriesDataPlot)
class TriggerHistogramPlot(TriggerPlotMixin, get_plot('histogram')):
    """Histogram of one column of the event trigger table(s)

    One histogram is drawn per channel, using the ``column`` of the
    triggers read for the configured ``etg``.
    """
    type = 'trigger-histogram'
    data = 'triggers'

    def __init__(self, *args, **kwargs):
        super(TriggerHistogramPlot, self).__init__(*args, **kwargs)
        self.etg = self.pargs.pop('etg')
        self.column = self.pargs.pop('column')

    @property
    def pid(self):
        """Unique identifier: ``<ETG>_<channel hash>[_<COLUMN>]``."""
        try:
            return self._pid
        except AttributeError:
            etg = re_cchar.sub('_', self.etg).upper()
            self._pid = '%s_%s' % (etg, super(TriggerHistogramPlot, self).pid)
            if self.column:
                self._pid += '_%s' % re_cchar.sub('_', self.column).upper()
            return self.pid

    @pid.setter
    def pid(self, id_):
        self._pid = str(id_)

    def draw(self):
        """Get data and generate the figure.
        """
        plot = self.init_plot()
        ax = plot.gca()

        # extract histogram and legend arguments
        histargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # the span to read is the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])

        # read the requested column for each channel
        data = []
        for channel in self.channels:
            try:
                channel = get_channel(channel)
            except ValueError:
                pass
            if '#' in str(channel) or '@' in str(channel):
                key = '%s,%s' % (str(channel),
                                 self.state and str(self.state) or 'All')
            else:
                key = str(channel)
            table_ = get_triggers(key, self.etg, valid, query=False)
            if self.filterstr is not None:
                table_ = table_.filter(self.filterstr)
            data.append(table_[self.column])
            # allow channel data to set parameters
            if hasattr(channel, 'amplitude_range'):
                self.pargs.setdefault('xlim', channel.amplitude_range)

        # plot
        orientation = None  # orientation of the last histogram drawn
        for arr, pargs in zip(data, histargs):
            # set range if not given: computed from the trigger data (not
            # the livetime, which was previously passed here by mistake)
            if pargs.get('range') is None:
                pargs['range'] = self._get_range(
                    data,
                    # use range from first dataset if already calculated
                    range=histargs[0].get('range'),
                    # use xlim if manually set (user or INI)
                    xlim=None if ax.get_autoscalex_on() else ax.get_xlim(),
                )
            pargs.setdefault('label', None)
            if pargs.get('log', True):
                # non-zero floor so empty bins survive a log y-axis
                pargs.setdefault('bottom', 1e-200)
            ax.hist(arr, **pargs)
            orientation = pargs.get('orientation', 'vertical')

        # tight scale the binned axis
        if orientation == 'vertical':
            ax.autoscale_view(tight=True, scaley=False)
        elif orientation == 'horizontal':
            ax.autoscale_view(tight=True, scalex=False)

        # customise plot
        self.apply_parameters(ax, **self.pargs)
        if len(self.channels) > 1:
            ax.legend(**legendargs)

        # finalise
        return self.finalize()


register_plot(TriggerHistogramPlot)
class TriggerRateDataPlot(TriggerPlotMixin, TimeSeriesDataPlot):
    """TimeSeriesDataPlot of trigger rate.

    Rates are computed with ``stride``-second bins; if ``column`` is given,
    one rate is drawn per entry of ``bins``, counting triggers whose
    ``column`` value satisfies ``operator`` (default ``'>='``) the bin.
    """
    type = 'trigger-rate'
    data = 'triggers'
    defaults = TimeSeriesDataPlot.defaults.copy()
    defaults.update({
        'column': None,
        'legend-bbox_to_anchor': (1., 1.),
        'legend-loc': 'upper left',
        'legend-markerscale': 3,
        'legend-frameon': False,
        'ylabel': 'Rate [Hz]',
    })

    def __init__(self, *args, **kwargs):
        if 'stride' not in kwargs:
            raise ValueError("'stride' must be configured for all rate plots.")
        if 'column' in kwargs and 'bins' not in kwargs:
            raise ValueError("'bins' must be configured for rate plots if "
                             "'column' is given.")
        super(TriggerRateDataPlot, self).__init__(*args, **kwargs)
        self.etg = self.pargs.pop('etg')
        self.column = self.pargs.pop('column')

    @property
    def pid(self):
        """Unique identifier: ``<ETG>_<channel hash>[_<COLUMN>]``.

        Cached in ``self._pid`` so every call returns the same value.
        """
        try:
            return self._pid
        except AttributeError:
            etg = re_cchar.sub('_', self.etg).upper()
            self._pid = '%s_%s' % (etg, super(TriggerRateDataPlot, self).pid)
            if self.column:
                self._pid += '_%s' % re_cchar.sub('_', self.column).upper()
            return self.pid

    @pid.setter
    def pid(self, id_):
        self._pid = str(id_)

    def draw(self):
        """Read in all necessary data, and generate the figure.

        Rate series are stored with `add_timeseries` under generated keys,
        then drawn by the parent time-series plot.
        """
        # get rate arguments
        stride = self.pargs.pop('stride')
        if self.column:
            cname = get_column_string(self.column)
            bins = self.pargs.pop('bins')
            operator = self.pargs.pop('operator', '>=')
            opstr = LATEX_OPERATOR.get(operator, str(operator))
        else:
            bins = ['_']

        # work out labels
        labels = self.pargs.pop('labels', None)
        if isinstance(labels, str):
            labels = labels.split(',')
        elif labels is None and self.column and len(self.channels) > 1:
            # channels may be Channel objects, so str() before joining
            labels = [' '.join([str(c), '$%s$' % opstr, str(b)])
                      for c in self.channels for b in bins]
            self.pargs.setdefault('legend-title', cname)
        elif labels is None and self.column:
            labels = [' '.join(['$%s$' % opstr, str(b)]) for b in bins]
            self.pargs.setdefault('legend-title', cname)
        elif labels is None:
            labels = self.channels
        self.pargs['labels'] = [str(s).strip('\n ') for s in labels]

        # get time column
        tcol = self.pargs.pop('timecolumn', None)

        # span and state tag are the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])
        statestr = str(self.state) if self.state else 'All'

        # generate data
        keys = []
        for channel in self.channels:
            if '#' in str(channel) or '@' in str(channel):
                key = '%s,%s' % (str(channel), statestr)
            else:
                key = str(channel)
            table_ = get_triggers(key, self.etg, valid, query=False)
            if self.filterstr is not None:
                table_ = table_.filter(self.filterstr)
            tcol_ = tcol or get_time_column(table_, self.etg)
            if self.column:
                rates = list(table_.binned_event_rates(
                    stride, self.column, bins, operator=operator,
                    start=self.start, end=self.end, timecolumn=tcol_).values())
            else:
                rates = [table_.event_rate(stride, start=self.start,
                                           end=self.end, timecolumn=tcol_)]
            for bin_, rate in zip(bins, rates):
                rate.channel = channel
                ratekey = '%s_%s_EVENT_RATE_%s_%s' % (
                    str(channel), str(self.etg), str(self.column), bin_)
                keys.append(ratekey)
                if ratekey not in globalv.DATA:
                    add_timeseries(rate, ratekey)

        # swap in the rate keys as channels and draw the time-series plot;
        # outputfile is resolved first, while self.channels still holds
        # the original channels
        channels = self.channels
        outputfile = self.outputfile
        self.channels = keys
        try:
            return super(TriggerRateDataPlot, self).draw(
                outputfile=outputfile)
        finally:
            self.channels = channels


register_plot(TriggerRateDataPlot)