Source code for suncasa.utils.mstools

import numpy as np
from tqdm import tqdm
import os
from shutil import rmtree

from ..casa_compat import import_casatools,import_casatasks

# CASA tasks used by this module, resolved through the compatibility layer.
tasks = import_casatasks('split', 'tclean', 'casalog', 'clearcal', 'gaincal')
split, tclean, casalog, clearcal, gaincal = (
    tasks.get(task_name)
    for task_name in ('split', 'tclean', 'casalog', 'clearcal', 'gaincal')
)

# CASA tool classes used by this module.
tools = import_casatools(['tbtool', 'mstool', 'qatool'])
tbtool, mstool, qatool = (
    tools[tool_name] for tool_name in ('tbtool', 'mstool', 'qatool')
)

# Module-level tool instances shared by the functions below.
tb, ms, qa = tbtool(), mstool(), qatool()
def get_bandinfo(msfile, spw=None, returnbdinfo=False, verbose=False):
    """
    Get the center frequencies (in GHz) of the spectral windows of a measurement set.

    :param str msfile: Path to the measurement set.
    :param spw: Optional spw selection: a single string or a list of strings, each
        following the standard CASA ``spw`` selection syntax. When empty or None,
        the center frequencies of all spws in the measurement set are returned.
    :param bool returnbdinfo: If True, return a dictionary with comprehensive
        frequency information instead of only the center frequencies.
    :param bool verbose: If True, print progress and parsing failures.
    :return: Array of center frequencies in GHz (of the requested selections when
        ``spw`` is given, otherwise of all spws), or the ``bdinfo`` dictionary when
        ``returnbdinfo`` is True. The dictionary always holds ``bounds_all``,
        ``cfreqs_all``, ``bounds_all_lo``, ``bounds_all_hi`` and ``nchans``; when
        ``spw`` is given it also holds ``bounds_lo``, ``bounds_hi`` and ``cfreqs``.
    """
    ms.open(msfile)
    try:
        spwInfo = ms.getspectralwindowinfo()
        per_spw = [spwInfo[str(s)] for s in range(len(spwInfo))]

        # All frequency quantities are converted to GHz by the 1e9 division.
        reffreqs = np.array([info['RefFreq'] for info in per_spw]) / 1e9
        bdwds = np.array([info['TotalWidth'] for info in per_spw]) / 1e9
        chanwds = np.array([info['ChanWidth'] for info in per_spw]) / 1e9
        nchans = np.array([info['NumChan'] for info in per_spw])
        cfreqs = reffreqs + bdwds / 2.0 - chanwds / 2.0

        bdinfo = {'bounds_all': np.hstack((reffreqs, reffreqs[-1] + bdwds[-1])),
                  'cfreqs_all': cfreqs,
                  'bounds_all_lo': reffreqs,
                  'bounds_all_hi': reffreqs + bdwds,
                  'nchans': nchans}

        if spw:
            if isinstance(spw, str):
                # A bare string is one selection; iterating over it directly
                # would try to parse it one character at a time.
                spw = [spw]

            freqbounds_lo_spw = []
            freqbounds_hi_spw = []
            cfreqs_spw = []
            for sp in spw:
                if verbose:
                    print(f'Parsing spw {sp}...')
                try:
                    ms.selectinit(reset=True)
                    ms.msselect({'spw': sp}, onlyparse=True)
                    chan_sel = ms.msselectedindices()['channel']
                    # First row gives the starting spw/channel, last row the ending ones.
                    bspw, bchan = chan_sel[0, 0], chan_sel[0, 1]
                    espw, echan = chan_sel[-1, 0], chan_sel[-1, 2]
                    binfo = spwInfo[str(bspw)]
                    einfo = spwInfo[str(espw)]
                    bfreq = (binfo['Chan1Freq'] + binfo['ChanWidth'] * bchan) / 1e9
                    efreq = (einfo['Chan1Freq'] + einfo['ChanWidth'] * echan) / 1e9
                    cfreq = (bfreq + efreq) / 2.
                except ValueError:
                    if verbose:
                        print("Parsing spw {} failed. Aborting...".format(sp))
                    continue
                freqbounds_lo_spw.append(bfreq)
                freqbounds_hi_spw.append(efreq)
                cfreqs_spw.append(cfreq)

            cfreqs = np.array(cfreqs_spw)
            bdinfo['bounds_lo'] = np.array(freqbounds_lo_spw)
            bdinfo['bounds_hi'] = np.array(freqbounds_hi_spw)
            bdinfo['cfreqs'] = cfreqs
    finally:
        # Always release the measurement set, even if parsing fails.
        ms.done()

    return bdinfo if returnbdinfo else cfreqs
def get_bmsize(cfreq, refbmsize=70.0, reffreq=1.0, minbmsize=4.0):
    """
    Get the beam size at the frequencies given by ``cfreq``, scaled from a reference.

    The beam size scales as ``refbmsize * reffreq / cfreq`` and is clipped from
    below at ``minbmsize``.

    :param cfreq: Input frequency or frequencies in GHz (scalar, sequence or array).
    :param float refbmsize: Reference beam size in arcsec.
    :param float reffreq: Reference frequency in GHz.
    :param float minbmsize: Minimum beam size in arcsec.
    :return: Beam size(s) in arcsec; an array for array-like input, a NumPy scalar
        for scalar input.
    """
    # asanyarray lets scalars and plain lists through; the original boolean-mask
    # assignment only worked on ndarray input.
    bmsize = refbmsize * reffreq / np.asanyarray(cfreq)
    return np.maximum(bmsize, minbmsize)
def get_trange(msfile):
    """
    Get the time range spanned by the rows of a measurement set.

    :param str msfile: Path to the measurement set.
    :return: ``astropy.time.Time`` holding the TIME of the first and the last row.
    """
    from astropy.time import Time

    tb.open(msfile)
    try:
        first = tb.getcell('TIME', 0)
        last = tb.getcell('TIME', tb.nrows() - 1)
    finally:
        # Release the table even if a cell cannot be read.
        tb.close()

    # TIME values are divided by 86400 and interpreted as MJD days.
    tr = np.array([first, last]) / 24. / 3600.
    return Time(tr, format='mjd')
def time2filename(msfile, timerange='', spw='', desc=False):
    """
    Build a file-name stem from the observatory name, a time range and an spw selection.

    :param str msfile: Path to the measurement set.
    :param str timerange: Optional ``start~end`` range. If ``start`` looks like
        ``HH:MM:SS...`` (a colon at the third character), both ends are combined with
        the date of the first row of the measurement set; otherwise both ends are
        converted with ``qa.quantity(..., 'd')`` and read as MJD. When empty or None,
        the first and last row times of the measurement set are used.
    :param str spw: Optional spw selection; ``~`` is replaced by ``-`` and the result
        is appended as ``.spw<selection>``.
    :param bool desc: If True, encode both start and end time; otherwise encode the
        mid-time only.
    :return: ``<observatory>_<YYYYmmddTHHMMSS.ffffff>[-<HHMMSS.ffffff>][.spw<...>]``.
    """
    from astropy.time import Time

    tb.open(msfile)
    try:
        starttim = Time(tb.getcell('TIME', 0) / 24. / 3600., format='mjd')
        endtim = Time(tb.getcell('TIME', tb.nrows() - 1) / 24. / 3600., format='mjd')
    finally:
        tb.close()
    datstr = starttim.iso[:10]

    ms.open(msfile)
    try:
        observatory = ms.metadata().observatorynames()[0]
    finally:
        ms.done()

    if timerange is None or timerange == '':
        starttim1, endtim1 = starttim, endtim
    else:
        tstart, tend = timerange.split('~')
        # Slice instead of index so short strings fall through to the quantity
        # parser rather than raising IndexError.
        if tstart[2:3] == ':':
            starttim1 = Time(datstr + 'T' + tstart)
            endtim1 = Time(datstr + 'T' + tend)
        else:
            starttim1 = Time(qa.quantity(tstart, 'd')['value'], format='mjd')
            endtim1 = Time(qa.quantity(tend, 'd')['value'], format='mjd')

    # The observatory name is kept out of the strftime format string so that a
    # literal '%' in it cannot be misread as a format directive.
    prefix = '{}_'.format(observatory)
    if desc:
        tstr = (prefix
                + starttim1.to_datetime().strftime('%Y%m%dT%H%M%S.%f')
                + endtim1.to_datetime().strftime('-%H%M%S.%f'))
    else:
        midtime = Time((starttim1.mjd + endtim1.mjd) / 2., format='mjd')
        tstr = prefix + midtime.to_datetime().strftime('%Y%m%dT%H%M%S.%f')

    if spw:
        spstr = 'spw{}'.format(spw.replace('~', '-'))
        return '.'.join([tstr, spstr])
    return tstr
def msclearhistory(msfile):
    """
    Clear the HISTORY subtable of a measurement set.

    :param str msfile: The name of the measurement set.
    :return: None
    """
    tb.open(msfile + '/HISTORY', nomodify=False)
    try:
        nrows = tb.nrows()
        if nrows > 0:
            # Pass a concrete list of row numbers: under Python 3 ``range`` is a
            # lazy object rather than a list.
            tb.removerows(list(range(nrows)))
    finally:
        # Release the table (and its write lock) even if the removal fails.
        tb.close()
def clearflagrow(msfile, mode='clear'):
    """
    List or clear the FLAG_ROW column of a measurement set.

    :param str msfile: Path to the measurement set.
    :param str mode: FLAG_ROW operation.
        ``'clear'`` (default): set every FLAG_ROW entry to False.
        ``'list'``: print how many rows are flagged.
        Any other value leaves the data untouched.
    :return: None
    """
    try:
        if mode == 'list':
            tb.open(msfile, nomodify=True)
            a = tb.getcol('FLAG_ROW')
            nfrows = np.sum(a)
            nrows = float(len(a))
            # Avoid a NaN percentage for a table without rows.
            percent = nfrows / nrows * 100 if nrows else 0.0
            print('{:0d} out of {:.0f} ({:.0f}%) rows are flagged in {}'.format(
                nfrows, nrows, percent, os.path.basename(msfile)))
        elif mode == 'clear':
            tb.open(msfile, nomodify=False)
            a = tb.getcol('FLAG_ROW')
            a[:] = False
            tb.putcol('FLAG_ROW', a)
            print('reset successfully')
    finally:
        # Always release the table, including when a read/write above fails.
        tb.close()
def splitX(vis, datacolumn2='MODEL_DATA', **kwargs):
    """
    Split a measurement set (MS) and carry an additional data column along.

    CASA's ``split`` writes a single data column into the output MS. This wrapper
    runs ``split`` twice: once for the requested ``datacolumn`` into ``outputvis``,
    and once for ``datacolumn2`` into a temporary MS, so that the additional column
    has the same shape as the data column of the output MS. The temporary column is
    then copied row by row into the output MS and the temporary MS is removed.

    :param str vis: Path to the original measurement set.
    :param str datacolumn2: Name of the additional data column to include.
        Defaults to 'MODEL_DATA'.
    :param kwargs: Keyword arguments passed to the CASA ``split`` task; must include
        ``outputvis``.
    :raises ValueError: If ``outputvis`` is missing or empty.
    :return: Path to the output measurement set.
    :rtype: str
    """
    datacolumn2 = datacolumn2.upper()
    outmsfile = kwargs.get('outputvis', '').rstrip('/')
    if not outmsfile:
        raise ValueError("splitX requires a non-empty 'outputvis' keyword argument.")

    # Clean up a previous output MS and its associated flag versions.
    for stale_path in (outmsfile, f"{outmsfile}.flagversions"):
        if os.path.exists(stale_path):
            rmtree(stale_path)

    # First split: data/corrected column into the DATA column of the output MS.
    split(vis=vis, **kwargs)

    # Second split: same selection, but the additional column into a temporary MS.
    # Strip a trailing slash so the temporary MS is created next to ``vis`` rather
    # than inside it.
    tmpms_file = f"{vis.rstrip('/')}.tmpms"
    kwargs2 = {k: v for k, v in kwargs.items() if k not in ('datacolumn', 'outputvis')}
    kwargs2.update({'outputvis': tmpms_file,
                    'datacolumn': datacolumn2.replace('_DATA', '')})
    if os.path.exists(tmpms_file):
        rmtree(tmpms_file)

    try:
        split(vis=vis, **kwargs2)

        # Read the additional column from the temporary MS.
        tb.open(tmpms_file)
        try:
            nrows = tb.nrows()
            data = [tb.getcell('DATA', row)
                    for row in tqdm(range(nrows), desc=f'Extracting {datacolumn2} column', ascii=True)]
        finally:
            tb.close()

        # Create the scratch columns, then write the additional column row by row.
        clearcal(outmsfile, addmodel=True)
        tb.open(outmsfile, nomodify=False)
        try:
            for row in tqdm(range(nrows), desc=f'writing {datacolumn2} column', ascii=True):
                tb.putcell(datacolumn2, row, data[row])
        finally:
            tb.close()
    finally:
        # Remove the temporary MS whether or not the copy succeeded.
        if os.path.exists(tmpms_file):
            rmtree(tmpms_file)

    return outmsfile
def flagcaltboutliers(caltable, limit=None):
    """
    Flag calibration solutions whose amplitude falls outside ``limit``.

    Entries of the ``CPARAM`` column whose absolute value is below ``limit[0]`` or
    above ``limit[1]`` get their ``FLAG`` set to True; existing flags are kept.

    :param str caltable: Path to the calibration table.
    :param limit: Two-element list (or tuple) ``[lower, upper]`` of allowed amplitudes.
    :return: 1 if the flags were updated; 0 if the table does not exist or ``limit``
        is invalid.
    :rtype: int
    """
    if not os.path.exists(caltable):
        return 0

    if limit is None:
        limit = []
    if not isinstance(limit, (list, tuple)):
        print('limit must be a list. Aborted!')
        return 0
    if len(limit) != 2:
        print('limit must have two elements. Aborted!')
        return 0

    tb.open(caltable, nomodify=False)
    try:
        data = tb.getcol('CPARAM')
        flag = tb.getcol('FLAG')
        datamag = np.abs(data)
        flag[(datamag < limit[0]) | (datamag > limit[1])] = True
        tb.putcol('FLAG', flag)
    finally:
        # Close the table so the write lock is released.
        tb.close()
    return 1
def modeltransfer(msfile, spw='', reference='XX', transfer='YY'):
    """
    Copy the model visibilities of one polarization onto another, in place.

    :param str msfile: Path to the measurement set; it is opened for writing.
    :param str spw: Spw selection. A range ``'a~b'`` is processed one spw at a time
        for every integer spw from ``a`` to ``b`` inclusive; any other value is
        passed to the selection as is.
    :param str reference: Polarization to copy from ('XX', 'YY', 'XY' or 'YX').
    :param str transfer: Polarization to overwrite ('XX', 'YY', 'XY' or 'YX').
    :raises KeyError: If ``reference`` or ``transfer`` is not a known polarization.
    :return: None
    """
    # Index of each polarization along the first axis of ``model_data``.
    # NOTE(review): this hard-codes the order XX, YY, XY, YX — confirm it matches
    # the correlation order of the measurement sets this is used on.
    pol_dict = {'XX': 0, 'YY': 1, 'XY': 2, 'YX': 3}
    refidx = pol_dict[reference]
    trfidx = pol_dict[transfer]

    if '~' in spw:
        sp0, sp1 = spw.split('~')
        selections = [str(sp) for sp in range(int(sp0), int(sp1) + 1)]
    else:
        selections = [spw]

    datams = mstool()
    datams.open(msfile, nomodify=False)
    try:
        for selection in selections:
            datams.selectinit(reset=True)
            datams.msselect({'spw': selection})
            modeldata = datams.getdata(['model_data'])
            modeldata['model_data'][trfidx, ...] = modeldata['model_data'][refidx, ...]
            datams.putdata(modeldata)
    finally:
        # Release the measurement set (and its write lock) even on failure.
        datams.close()
def concat_slftb(tb_in=None, tb_out=None):
    """
    Concatenate several calibration tables into a single output table.

    The first input table is copied to ``tb_out``; the columns of all non-empty
    input tables are then concatenated and written into that copy. The ``WEIGHT``
    column is skipped.

    :param tb_in: List of paths to the input tables.
    :param str tb_out: Path of the output table; an existing one is replaced.
    :return: ``tb_out`` on success, -1 if no input tables were given or none of
        them contains data.
    """
    from shutil import copytree

    if not tb_in:
        print('tb_in not provided. Abort...')
        # Nothing to concatenate; bail out instead of failing on tb_in[0] below.
        return -1

    # Replace any previous output with a fresh copy of the first input table.
    if os.path.exists(tb_out):
        if os.path.isdir(tb_out) and not os.path.islink(tb_out):
            rmtree(tb_out)
        else:
            os.remove(tb_out)
    copytree(tb_in[0], tb_out, symlinks=True)

    tb.open(tb_out)
    try:
        cols = tb.colnames()
    finally:
        tb.close()
    if 'WEIGHT' in cols:
        cols.remove('WEIGHT')

    # Columns that are indexed [::-1, ...] for the second table and concatenated
    # along axis 2; all other columns are concatenated along axis 0.
    param_cols = ('CPARAM', 'PARAMERR', 'FLAG', 'SNR')

    tbdata = {col: [] for col in cols}
    for tbidx, ctb in enumerate(tb_in):
        tb.open(ctb, nomodify=True)
        try:
            if len(tb.getcol(cols[0])) == 0:
                # Empty table: nothing to contribute.
                continue
            for col in cols:
                coldata = tb.getcol(col)
                if tbidx == 1 and col in param_cols:
                    # NOTE(review): the first axis is reversed for the second
                    # table only — presumably to move its solutions into the
                    # other polarization slot; confirm against gaincalXY.
                    coldata = coldata[::-1, ...]
                tbdata[col].append(coldata)
        finally:
            tb.close()

    if len(tbdata[cols[0]]) == 0:
        print('tables have no data. Return')
        return -1

    for col in cols:
        axis = 2 if col in param_cols else 0
        tbdata[col] = np.concatenate(tbdata[col], axis=axis)

    tb.open(tb_out, nomodify=False)
    try:
        nrows = tb.nrows()
        nrows_new = len(tbdata[cols[0]])
        # Grow the copy of the first table to hold the rows of all tables.
        tb.addrows(nrows_new - nrows)
        for col in cols:
            tb.putcol(col, tbdata[col])
    finally:
        tb.close()
    return tb_out
def gaincalXY(vis=None, caltable=None, pols='XXYY', msfileXY=None, gaintableXY=None, **kwargs):
    """
    Run ``gaincal`` separately per polarization and merge the resulting tables.

    For every polarization a single-polarization measurement set is used (created
    with ``splitX`` unless supplied through ``msfileXY``), ``gaincal`` is run on it
    into ``<caltable>.<pol>``, and the per-polarization tables are finally merged
    into ``caltable`` with ``concat_slftb``.

    :param str vis: Path to the input measurement set.
    :param str caltable: Path of the merged output calibration table.
    :param str pols: Comma-separated polarizations; ``'XXYY'`` is shorthand for
        ``'XX,YY'``.
    :param dict msfileXY: Optional mapping from polarization to an existing
        single-polarization measurement set. When None, the measurement sets are
        created as ``<vis>.<pol>`` and removed again at the end.
    :param dict gaintableXY: Optional mapping from polarization to the ``gaintable``
        value passed to ``gaincal`` for that polarization; it overrides any
        ``gaintable`` in ``kwargs``.
    :param kwargs: Additional keyword arguments passed to ``gaincal``.
    :return: None
    """
    if pols == 'XXYY':
        pols = 'XX,YY'
    pols_ = pols.split(',')

    rm_msfileXY = msfileXY is None
    if rm_msfileXY:
        msfileXY = {}
        for pol in pols_:
            msfileXY[pol] = '.'.join([vis, pol])
            if os.path.exists(msfileXY[pol]):
                rmtree(msfileXY[pol], ignore_errors=True)
            splitX(vis=vis, outputvis=msfileXY[pol], correlation=pol, datacolumn='data',
                   datacolumn2='MODEL_DATA')

    if gaintableXY is not None:
        # The per-polarization gain tables replace any caller-supplied one.
        kwargs.pop('gaintable', None)

    caltbXY = []
    for pol in pols_:
        caltb_ = '.'.join([caltable, pol])
        if gaintableXY is not None:
            kwargs['gaintable'] = gaintableXY[pol]
        gaincal(vis=msfileXY[pol], caltable=caltb_, **kwargs)
        caltbXY.append(caltb_)

    concat_slftb(caltbXY, caltable)

    if rm_msfileXY:
        # dict.iteritems() does not exist in Python 3; only the paths are needed.
        for msname in msfileXY.values():
            rmtree(msname, ignore_errors=True)
    return
def getmodel(vis, spw=None):
    """
    Read the MODEL_DATA column for the rows of one data description ID.

    :param str vis: Path to the measurement set.
    :param spw: Value compared against ``DATA_DESC_ID`` in the row query; it must be
        provided for the query to be valid.
    :return: The ``MODEL_DATA`` column of the selected rows.
    """
    tb.open(vis, nomodify=True)
    try:
        subt = tb.query("DATA_DESC_ID==" + str(spw))
        try:
            return subt.getcol('MODEL_DATA')
        finally:
            subt.done()
    finally:
        # Release the main table even if the query or the read fails.
        tb.done()
def putmodel(vis, spw=None, model=None):
    """
    Write the MODEL_DATA column for the rows of one data description ID.

    :param str vis: Path to the measurement set; it is opened for writing.
    :param spw: Value compared against ``DATA_DESC_ID`` in the row query; it must be
        provided for the query to be valid.
    :param model: Data to store in ``MODEL_DATA`` for the selected rows.
    :return: The value returned by the table tool's ``putcol`` call.
    """
    tb.open(vis, nomodify=False)
    try:
        subt = tb.query("DATA_DESC_ID==" + str(spw))
        try:
            return subt.putcol('MODEL_DATA', model)
        finally:
            subt.done()
    finally:
        # Release the main table (and its write lock) even if the write fails.
        tb.done()
from tqdm import tqdm import time
def sort_polarization_order(msfile, outputvis):
    """
    Write a copy of a measurement set with its correlations sorted by CORR_TYPE.

    The input is copied to ``outputvis``; in the copy, ``CORR_TYPE`` and
    ``CORR_PRODUCT`` of the POLARIZATION subtable and the first axis of every
    ``DATA`` cell are reordered by ascending ``CORR_TYPE`` of the first
    POLARIZATION row (intended to yield the conventional XX, XY, YX, YY order).

    NOTE(review): only the DATA column is reordered; other per-correlation columns
    (e.g. FLAG, WEIGHT, SIGMA) keep their original order — confirm this is intended.

    :param str msfile: Path to the input measurement set.
    :param str outputvis: Path of the output measurement set; an existing one is
        replaced.
    :raises OSError: If the output cannot be removed or the copy fails.
    :raises Exception: Errors from the CASA table operations are printed and re-raised.
    :return: Path to the output measurement set.
    :rtype: str
    """
    from shutil import copytree

    t_start = time.time()

    # Start from a fresh copy of the input MS.
    if os.path.exists(outputvis):
        rmtree(outputvis)
    try:
        copytree(msfile, outputvis, symlinks=True)
    except OSError as err:
        raise OSError("Failed to copy the MS file to the output location.") from err

    try:
        tb.open(outputvis + '/POLARIZATION/', nomodify=False)
        corr_type = tb.getcol('CORR_TYPE')
        # Sorting permutation derived from the first POLARIZATION row.
        corr_order = np.argsort(corr_type[:, 0])
        tb.putcol('CORR_TYPE', corr_type[corr_order])
        corr_product = tb.getcol('CORR_PRODUCT')
        # Apply the same permutation to the correlation axis of CORR_PRODUCT.
        tb.putcol('CORR_PRODUCT', corr_product[:, corr_order])
    except Exception as e:
        print("An error occurred while sorting the polarization order in POLARIZATION table:", e)
        raise
    finally:
        tb.close()

    try:
        # The output is a copy of the input, so its DATA cells can be reordered in
        # place, one row at a time, without buffering the whole column.
        tb.open(outputvis, nomodify=False)
        for row in tqdm(range(tb.nrows()), desc='Reordering DATA column', ascii=True):
            tb.putcell('DATA', row, tb.getcell('DATA', row)[corr_order])
    except Exception as e:
        print("An error occurred while sorting the polarization order in DATA column:", e)
        raise
    finally:
        tb.close()

    t_end = time.time()
    print(f"Sorting the polarization order in the measurement set took {(t_end - t_start) / 60:.1f} minutes.")
    return outputvis