Source code for orca.flagging.flag_bad_chans

#!/usr/bin/env python

"""
Copied from Marin Anderson's code, 3/8/2019.
"""
from __future__ import division
import numpy as np
import casacore.tables as pt
import os,argparse
import numpy.ma as ma
import logging
from scipy.ndimage import filters


def _find_bad_channels(datacolamp_mask, crosshand=False):
    """Return the sorted, unique indices of channels judged bad.

    A (channel, correlation) cell is flagged when either:
      * the peak amplitude over rows, normalized by a 25-channel running median,
        falls outside ``1 +/- threshold * min(running std of the normalized peak)``;
      * the mean amplitude over rows exceeds the median over channels by more than
        ``100 * min(running std of the mean amplitude)``.

    Args:
        datacolamp_mask: masked visibility amplitudes indexed as (row, channel, correlation);
            masked entries are data that were already flagged. The per-correlation
            threshold vector below assumes exactly 4 correlations.
        crosshand: If True, detections in correlations 1 and 2 also flag a channel;
            otherwise only correlations 0 and 3 are considered.

    Returns:
        1-D integer array of channel indices (possibly empty).
    """
    maxamps = np.ma.max(datacolamp_mask, axis=0)
    meanamps = np.ma.mean(datacolamp_mask, axis=0)
    # size=(25, 1): the window runs along the channel axis only, so each correlation
    # is filtered independently.
    maxamps_medfilt = filters.median_filter(maxamps, size=(25, 1))
    maxamps_norm = maxamps / maxamps_medfilt
    maxamps_norm_stdfilt = filters.generic_filter(maxamps_norm, np.std, size=(25, 1))
    # Per-correlation multipliers: 10 for correlations 0/3, 6 for correlations 1/2.
    threshold_vec = np.array([10, 6, 6, 10])
    min_norm_std = np.ma.min(maxamps_norm_stdfilt, axis=0)
    maxamps_lower = 1 - threshold_vec * min_norm_std
    maxamps_upper = 1 + threshold_vec * min_norm_std
    meanamps_stdfilt = filters.generic_filter(meanamps, np.std, size=(25, 1))
    meanamps_limit = (np.ma.median(meanamps, axis=0)
                      + 100 * np.ma.min(meanamps_stdfilt, axis=0))
    chan_idx, corr_idx = np.where(
        (maxamps_norm < maxamps_lower)
        | (maxamps_norm > maxamps_upper)
        | (meanamps > meanamps_limit))
    if not crosshand:
        # Ignore detections that only occur in correlations 1 and 2.
        chan_idx = chan_idx[(corr_idx == 0) | (corr_idx == 3)]
    return np.unique(chan_idx)


def _plot_amplitudes(datacolamp_mask, flaglist, plotfile):
    """Save a quick-look plot of amplitude vs channel for correlations 0 and 3.

    Channels listed in ``flaglist`` are omitted, so the plot shows what remains after
    flagging. This is intended only as a check of how well the flagging performed.
    """
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    nchans = datacolamp_mask.shape[1]
    fig = plt.figure(figsize=(5, 10))
    for chan in range(nchans):
        if chan in flaglist:
            continue
        chanpts = np.zeros(len(datacolamp_mask[:, chan, 0])) + chan
        plt.plot(datacolamp_mask[:, chan, 0], chanpts, '.', color='Blue', markersize=0.5)
        plt.plot(datacolamp_mask[:, chan, 3], chanpts, '.', color='Green', markersize=0.5)
    plt.ylim([0, nchans - 1])
    plt.ylabel('channel')
    plt.xlabel('Amp')
    plt.gca().invert_yaxis()
    plt.savefig(plotfile)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def flag_bad_chans(msfile: str, band: str, usedatacol=False, generate_plot=False,
                   apply_flag=False, crosshand=False, uvcut_m: float = None):
    """Flag bad channels.

    Finds remaining bad channels and flags those in the measurement set. Also writes a
    text file (``<ms basename>.chans``, one ``BB:CCC`` line per channel) that lists the
    flagged channels.

    Args:
        msfile: measurement set to flag.
        band: spectral window / subband number; converted with ``int()`` and written,
            zero-padded, into the ``.chans`` file.
        usedatacol: If True, uses DATA column, else use CORRECTED_DATA.
        generate_plot: If True, save an amplitude-vs-channel plot as ``<ms basename>.png``.
        apply_flag: If True, set FLAG for the bad channels on every row of the MS
            (all rows of the table, not only the cross-correlations used for detection).
        crosshand: If True, it will use the XY and YX correlations (indices 1 and 2) when
            determining flags. Otherwise only correlations 0 and 3 are used.
        uvcut_m: If given (and non-zero), only rows with sqrt(u^2 + v^2) greater than this
            many meters are used for thresholding, to suppress short-baseline flux.

    Returns:
        The input ``msfile`` path.
    """
    basename = os.path.splitext(os.path.abspath(msfile))[0]
    with pt.table(msfile, readonly=False) as t:
        # Detection uses cross-correlations only; the query table is closed once read.
        with t.query('ANTENNA1!=ANTENNA2') as tcross:
            datacol = tcross.getcol('DATA' if usedatacol else 'CORRECTED_DATA')
            flagcol = tcross.getcol('FLAG')
            if uvcut_m:
                uvw = tcross.getcol('UVW')
                uvdist = np.sqrt(uvw[:, 0]**2. + uvw[:, 1]**2.)
                keep = uvdist > uvcut_m
                datacol = datacol[keep]
                flagcol = flagcol[keep]

        # Already-flagged samples are masked so they do not influence the statistics.
        datacolamp_mask = ma.masked_array(np.abs(datacol), mask=flagcol, fill_value=np.nan)
        flaglist = _find_bad_channels(datacolamp_mask, crosshand)

        if generate_plot:
            _plot_amplitudes(datacolamp_mask, flaglist, basename + '.png')

        logging.info('Flaglist size is %i', flaglist.size)
        if flaglist.size > 0:
            # Turn flaglist into a text file of channel flags.
            with open(basename + '.chans', 'w') as f:
                for chan in flaglist:
                    # Builtin int(): the np.int alias was removed in NumPy 1.24.
                    f.write('%02d:%03d\n' % (int(band), chan))
            # Write flags into the FLAG column.
            if apply_flag:
                logging.info('Applying the changes to the measurement set.')
                flagcol_altered = t.getcol('FLAG')
                flagcol_altered[:, flaglist, :] = True
                t.putcol('FLAG', flagcol_altered)
    return msfile
def main(argv=None):
    """Command-line entry point: parse arguments and run ``flag_bad_chans``.

    Args:
        argv: Optional list of argument strings. When None, argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Flag bad channels and write out list of channels that were "
                    "flagged into text file of same name as ms. MUST BE RUN ON "
                    "SINGLE SUBBAND MS.")
    parser.add_argument("msfile", help="Measurement set.")
    parser.add_argument("band", help="Subband number.")
    parser.add_argument("--usedatacol", action="store_true", default=False,
                        help="Grab DATA column, not CORRECTED_DATA.")
    parser.add_argument('--plot', action='store_true', default=False,
                        help='Generate plot of amp vs channel.')
    parser.add_argument('--apply-flag', action='store_true', default=False,
                        help='Apply flags to measurement set.')
    parser.add_argument('--crosshand', action='store_true', default=False,
                        help='Use the cross-hand visibilities also.')
    # argparse does not expand "{...}" in help text, so the option is named explicitly.
    parser.add_argument('--uvcut_m', action='store', type=float, default=None,
                        help='Only use visibilities with uv distance greater than UVCUT_M '
                             'meters when determining channel flags. Default is None (no cut).')
    args = parser.parse_args(argv)
    flag_bad_chans(args.msfile, args.band, usedatacol=args.usedatacol,
                   generate_plot=args.plot, apply_flag=args.apply_flag,
                   crosshand=args.crosshand, uvcut_m=args.uvcut_m)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()