#!/usr/bin/env python
"""
Copy from Marin Anderson 3/8/2019
"""
from __future__ import division
import numpy as np
import casacore.tables as pt
import os,argparse
import numpy.ma as ma
import logging
from scipy.ndimage import filters
def flag_bad_chans(msfile: str, band: str, usedatacol=False, generate_plot=False, apply_flag=False, crosshand=False,
                   uvcut_m: float = None):
    """Flag bad channels.

    Finds remaining bad channels in a single-subband measurement set. It writes the
    flagged channel numbers to ``<msfile stem>.chans`` and can optionally set FLAG
    for those channels.

    Statistics come from cross-correlation rows only (``ANTENNA1 != ANTENNA2``).
    For every (channel, correlation) pair, the maximum and the mean unflagged
    amplitude over rows are compared against running median / standard-deviation
    references along the channel axis. Outlier channels are flagged.

    Args:
        msfile: measurement set to flag.
        band: spectral window (subband) number. It must be convertible with
            ``int()`` because it is written as a two-digit prefix in the
            ``.chans`` file.
        usedatacol: If True, uses DATA column, else use CORRECTED_DATA.
        generate_plot: If True, save an amplitude-vs-channel plot of the
            unflagged channels (correlations 0 and 3) to ``<msfile stem>.png``.
        apply_flag: If True, set FLAG for the bad channels in every row of the
            measurement set (autocorrelation rows included) and write it back.
            The table is opened for writing only in this case.
        crosshand: If True, a channel is flagged when any correlation is an
            outlier. Otherwise only correlations 0 and 3 (XX/YY) are considered,
            and outliers found only in the cross-hand correlations 1 and 2
            (XY/YX) are ignored.
        uvcut_m: If given, only rows whose projected baseline length
            ``sqrt(u**2 + v**2)`` exceeds this value (meters) are used to compute
            the statistics. This suppresses short-baseline flux.

    Returns:
        The ``msfile`` argument, unchanged.

    Raises:
        ValueError: if no cross-correlation rows remain to compute statistics
            (e.g. ``uvcut_m`` removes every row).
    """
    stem = os.path.splitext(os.path.abspath(msfile))[0]

    # Only request write access when flags will actually be written back.
    with pt.table(msfile, readonly=not apply_flag) as t:
        # Use cross-correlations only; close the query sub-table once read.
        with t.query('ANTENNA1!=ANTENNA2') as tcross:
            datacol = tcross.getcol('DATA' if usedatacol else 'CORRECTED_DATA')
            flagcol = tcross.getcol('FLAG')
            if uvcut_m:
                uvw = tcross.getcol('UVW')
                uvdist = np.sqrt(uvw[:, 0] ** 2. + uvw[:, 1] ** 2.)
                keep = uvdist > uvcut_m
                datacol = datacol[keep]
                flagcol = flagcol[keep]

        if datacol.shape[0] == 0:
            raise ValueError('No cross-correlation rows left to compute channel '
                             'statistics for %s (check uvcut_m).' % msfile)

        # Amplitudes with already-flagged samples masked out of the statistics.
        amp = ma.masked_array(np.abs(datacol), mask=flagcol, fill_value=np.nan)
        flaglist = _find_bad_channels(amp, crosshand)

        # Testing aid: quick visual check of how well the flagging performed.
        if generate_plot:
            _plot_channel_amplitudes(amp, flaglist, stem + '.png')

        logging.info('Flaglist size is %i', flaglist.size)
        if flaglist.size > 0:
            _write_chan_file(stem + '.chans', band, flaglist)
            if apply_flag:
                logging.info('Applying the changes to the measurement set.')
                # Read FLAG from the full table so every row gets the channel flags.
                flags = t.getcol('FLAG')
                flags[:, flaglist, :] = True
                t.putcol('FLAG', flags)
    return msfile


def _find_bad_channels(amp, crosshand, window=25):
    """Return sorted unique indices of channels whose amplitude statistics are outliers.

    ``amp`` is a masked array indexed (row, channel, correlation) with flagged
    samples masked. The thresholds assume 4 correlations (XX, XY, YX, YY order);
    ``window`` is the running-filter width in channels.
    """
    from scipy import ndimage  # the scipy.ndimage.filters namespace is deprecated

    # Per (channel, correlation): max and mean unflagged amplitude over all rows.
    maxamps = np.ma.max(amp, axis=0)
    meanamps = np.ma.mean(amp, axis=0)

    # Normalise the max amplitude by its running median along channels and
    # measure the local scatter of that ratio.
    maxamps_medfilt = ndimage.median_filter(maxamps, size=(window, 1))
    maxamps_norm = maxamps / maxamps_medfilt
    maxamps_norm_stdfilt = ndimage.generic_filter(maxamps_norm, np.std, size=(window, 1))

    # Tolerance in units of the quietest local scatter of each correlation:
    # 10 for correlations 0 and 3, 6 for correlations 1 and 2.
    threshold_vec = np.array([10, 6, 6, 10])
    scatter = np.ma.min(maxamps_norm_stdfilt, axis=0)
    maxamps_lower = 1 - threshold_vec * scatter
    maxamps_upper = 1 + threshold_vec * scatter

    # Mean amplitude far above the per-correlation median is also bad.
    meanamps_stdfilt = ndimage.generic_filter(meanamps, np.std, size=(window, 1))
    mean_limit = np.ma.median(meanamps, axis=0) + 100 * np.ma.min(meanamps_stdfilt, axis=0)

    chan_idx, corr_idx = np.where((maxamps_norm < maxamps_lower) | (maxamps_norm > maxamps_upper) |
                                  (meanamps > mean_limit))
    if not crosshand:
        # Keep only detections in correlations 0 and 3 (XX/YY).
        chan_idx = chan_idx[(corr_idx == 0) | (corr_idx == 3)]
    return np.unique(chan_idx)


def _plot_channel_amplitudes(amp, flaglist, plotfile):
    """Save an amplitude-vs-channel scatter (correlations 0 and 3) of channels not in ``flaglist``."""
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    nrow, nchan = amp.shape[0], amp.shape[1]
    bad = set(flaglist.tolist())
    fig = plt.figure(figsize=(5, 10))
    for chan in range(nchan):
        if chan in bad:
            continue
        chanpts = np.full(nrow, chan, dtype=float)
        plt.plot(amp[:, chan, 0], chanpts, '.', color='Blue', markersize=0.5)
        plt.plot(amp[:, chan, 3], chanpts, '.', color='Green', markersize=0.5)
    plt.ylim([0, nchan - 1])
    plt.ylabel('channel')
    plt.xlabel('Amp')
    plt.gca().invert_yaxis()
    fig.savefig(plotfile)
    plt.close(fig)


def _write_chan_file(textfile, band, flaglist):
    """Write one ``BB:CCC`` line (two-digit band, three-digit channel) per flagged channel."""
    band_num = int(band)
    with open(textfile, 'w') as f:
        f.writelines('%02d:%03d\n' % (band_num, chan) for chan in flaglist)
def main(argv=None):
    """Command-line entry point: parse arguments and run ``flag_bad_chans``.

    Args:
        argv: Argument list to parse. When None, argparse uses ``sys.argv[1:]``.
    """
    # Implicit concatenation avoids embedding source indentation in the text.
    parser = argparse.ArgumentParser(
        description="Flag bad channels and write out list of channels that were "
                    "flagged into text file of same name as ms. MUST BE RUN ON "
                    "SINGLE SUBBAND MS.")
    parser.add_argument("msfile", help="Measurement set.")
    parser.add_argument("band", help="Subband number.")
    parser.add_argument("--usedatacol", action="store_true", default=False,
                        help="Grab DATA column, not CORRECTED_DATA.")
    parser.add_argument('--plot', action='store_true', default=False,
                        help='Generate plot of amp vs channel.')
    parser.add_argument('--apply-flag', action='store_true', default=False,
                        help='Apply flags to measurement set.')
    parser.add_argument('--crosshand', action='store_true', default=False,
                        help='Use the cross-hand visibilities also.')
    parser.add_argument('--uvcut_m', action='store', type=float, default=None,
                        help='Only use visibilities on baselines longer than UVCUT_M meters '
                             'when determining channel flags. Default is None (no cut).')
    args = parser.parse_args(argv)
    flag_bad_chans(args.msfile, args.band, usedatacol=args.usedatacol,
                   generate_plot=args.plot, apply_flag=args.apply_flag,
                   crosshand=args.crosshand, uvcut_m=args.uvcut_m)
if __name__ == "__main__":
    # Run as a script: parse the command line and flag the measurement set.
    main()