From fa78fe7e06a2bc61cbc5201d743cc110f7da53dd Mon Sep 17 00:00:00 2001 From: vieenrose Date: Sat, 27 Apr 2019 05:25:09 +0200 Subject: [PATCH] add automatic time reference tier detection. it will be used in the case that no one in the predefine list of time reference tier names is avaliable in the Praat TextGrid obj --- exporter.py | 310 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 178 insertions(+), 132 deletions(-) diff --git a/exporter.py b/exporter.py index 10276f0..4ded577 100644 --- a/exporter.py +++ b/exporter.py @@ -35,37 +35,37 @@ def get_encoding(filepath): # extend original TextGrid reader to support praat Collection / Analor .or class TextGridPlus(pympi.Praat.TextGrid): - + def extractTextGridFromAnalorFile(self,ifile): - + SuccessOrNot = False - + # not process Analor file when javaobj is not avaliable if not javaobj_installed: return SuccessOrNot - + try: marshaller = javaobj.JavaObjectUnmarshaller(ifile) except IOError: ifile.seek(0, 0) return SuccessOrNot - + while True: # get one object pobj = marshaller.readObject() if pobj == 'FIN' or \ - pobj == '' : + pobj == '' : break if pobj == 'F0': self.xmin, self.xmax = marshaller.readObject() - + # check if is at the tiers' header if pobj == 'TIRES': - - # get tier number + + # get tier number tier_num = marshaller.readObject() tier_num = struct.unpack('>i',tier_num)[0] - + while tier_num : # get the metadata of tier tlims = marshaller.readObject() @@ -74,18 +74,18 @@ def extractTextGridFromAnalorFile(self,ifile): mots = marshaller.readObject() bornes = marshaller.readObject() nomGuide = marshaller.readObject() - + # translation between 2 type naming # between Analor and Praat version - if typ == 'INTERVALLE' : + if typ == 'INTERVALLE' : tier_type = 'IntervalTier' - elif typ == 'POINT' : + elif typ == 'POINT' : tier_type = 'TextTier' - else : + else : raise Exception('Tiertype does not exist.') - + # form a tier - tier = pymip.Praat.Tier(0, 0, name=nom, tier_type=tier_type) + tier = pymip.Praat.Tier(0, 0, name=nom, tier_type=tier_type) self.tiers.append(tier) tier.xmin = tlims[0] tier.xmax = tlims[-1] @@ -99,14 +99,14 @@ def extractTextGridFromAnalorFile(self,ifile): raise Exception('Tiertype does not exist.') # uncount the number of tiers remain to process - if tier_num >0: + if tier_num >0: tier_num -= 1; - + SuccessOrNot = True - + ifile.seek(0, 0) return SuccessOrNot - + def from_file(self, ifile, codec='ascii'): """Read textgrid from stream. @@ -114,9 +114,9 @@ def from_file(self, ifile, codec='ascii'): :param str codec: Text encoding for the input. Note that this will be ignored for binary TextGrids. """ - + # extract TextGrid form Analor file (.or) - if self.extractTextGridFromAnalorFile(ifile) : + if self.extractTextGridFromAnalorFile(ifile) : pass # read a Textgrid or extract TextGrid from Collection in Binary Format elif ifile.read(12) == b'ooBinaryFile': @@ -139,7 +139,7 @@ def bin2str(ifile): # jump to the begining of the embedded TextGrid object if ifile.read(ord(ifile.read(1))) == b'Collection': # skip oo type self.jump2TextGridBin(ifile, codec) - + self.xmin = struct.unpack('>d', ifile.read(8))[0] self.xmax = struct.unpack('>d', ifile.read(8))[0] ifile.read(1) # skip @@ -188,7 +188,7 @@ def nn(ifile, pat): #print(name) #debug tier = pympi.Praat.Tier(0, 0, name=name, tier_type=tier_type) self.tiers.append(tier) - + tier.xmin = float(nn(ifile, regfloat)) tier.xmax = float(nn(ifile, regfloat)) for i in range(int(nn(ifile, regint))): @@ -204,14 +204,14 @@ def nn(ifile, pat): def jump2TextGridBin(self, ifile, codec='ascii', keyword = b'\x08TextGrid'): binstr = b'' - while ifile: + while ifile: binstr += ifile.read(1) - if len(binstr) > len(keyword): + if len(binstr) > len(keyword): binstr = binstr[1:] - if binstr == keyword : + if binstr == keyword : break lg = struct.unpack('>h', ifile.read(2))[0] - if lg == -1 : + if lg == -1 : lg = lg.astype('>H') objname = ifile.read(lg).decode('ascii') # skip embeded oo name @@ -219,10 +219,10 @@ def insert_to_basename(filename, inserted, new_ext_name = None): basename, extension = os.path.splitext(filename) if new_ext_name: extension = u'.'+new_ext_name return basename + inserted + extension - -def deb_print(x) : + +def deb_print(x) : if DEBUG : print(x) - + # source : https://stackoverflow.com/questions/2460177/edit-distance-in-python def edit_distance(s1, s2): m=len(s1)+1 @@ -237,18 +237,18 @@ def edit_distance(s1, s2): tbl[i,j] = min(tbl[i, j-1]+1, tbl[i-1, j]+1, tbl[i-1, j-1]+cost) return tbl[i,j] - + def distance(s1, s2) : # retirer des signes de marcro qui ne sont pas présentes dans le tier de ref. - macrosyntax_signs = re.compile(r"[\#\&\(\)\[\]\/\|\+\s\<\>]") + macrosyntax_signs = re.compile(r"[\#\&\(\)\[\]\/\|\+\s\<\>]") s1 = re.sub(macrosyntax_signs, "", s1.lower()) s2 = re.sub(macrosyntax_signs, "", s2.lower()) dist = edit_distance(s1, s2[:len(s1)]) return dist - -def findTimes (tokens, refTier, lowerbound, upperbound = -1, thld = 0.1) : - + +def findTimes (tokens, refTier, lowerbound, upperbound = -1, thld = 0.1) : + sent = ' '.join(tokens) intvs = refTier.get_all_intervals() ref_tokens = [intv[-1] for intv in intvs] @@ -259,7 +259,7 @@ def findTimes (tokens, refTier, lowerbound, upperbound = -1, thld = 0.1) : best_begin_ref_sent = '' best_end_ref_sent = '' width = 2 * len(tokens) - + # détection du début temporel if upperbound < 0: # interprete negative upper bound as unbounded case upperbound = len(ref_tokens) @@ -270,9 +270,9 @@ def findTimes (tokens, refTier, lowerbound, upperbound = -1, thld = 0.1) : # adapt real width if necessary try : ref_tokens_sampled = ref_tokens[n:n+width] except IndexError: ref_tokens_sampled = ref_tokens[n:] - # check if the current token represnts a pause + # check if the current token represnts a pause if ref_tokens[n] == pauseSign or not(ref_tokens[n]) : continue # interdiction d'aligner le début de la phrase sur une pause ou un vide - + # search the begining ref_sent = ' '.join(ref_tokens_sampled) dist = distance(sent, ref_sent) @@ -282,9 +282,9 @@ def findTimes (tokens, refTier, lowerbound, upperbound = -1, thld = 0.1) : best_begin_ref_sent = ref_sent #deb_print('\t@findTimes best distance: {}'.format(best_dist)) #deb_print("\t@findTimes best begin sentence ({}) found : '{}'".format(best_begin_n, best_begin_ref_sent)) - + tmin = intvs[best_begin_n][0] # begining time of the starting interval - + # détection de la vraie fin temporelle best_dist = -1 best_sent = '' @@ -293,12 +293,12 @@ def findTimes (tokens, refTier, lowerbound, upperbound = -1, thld = 0.1) : end_n = best_begin_n + width ref_sent = ' '.join(ref_tokens[best_begin_n:end_n]) dist = distance(sent[::-1], ref_sent[::-1]) - if best_dist < 0 or dist <= best_dist : + if best_dist < 0 or dist <= best_dist : best_dist = dist best_end_n = end_n best_sent = ref_sent width -= 1 - + # verify if dist < 10% of sentence length deb_print("\t@findTimes sent to match : '{}'".format(sent)) if best_dist > thld * (len(sent)**1.1): @@ -310,14 +310,15 @@ def findTimes (tokens, refTier, lowerbound, upperbound = -1, thld = 0.1) : cursor_out = best_end_n deb_print("\t@findTimes sent found : '{}'".format(best_sent, tmin, tmax)) - return [tmin, tmax, cursor_out] - + #print(best_dist)#debug + return [tmin, tmax, cursor_out, best_dist] + def one_to_many_pairing (file1, files2, thld = 5): matched = '' maxlen = -1 doublon = False - + for file2 in files2 : nonenone,nonenone,match_len = \ difflib.SequenceMatcher(None, file1.lower(), file2.lower()).\ @@ -333,42 +334,120 @@ def one_to_many_pairing (file1, files2, thld = 5): if doublon : matched = '' return matched - + def make_paires(files1, files2): # fine 1-to-1 file pair pairs = [] for f1 in files1 : f2 = one_to_many_pairing(f1, files2) if f2 : - if f1 == one_to_many_pairing(f2, files1) : + if f1 == one_to_many_pairing(f2, files1) : pairs.append((f1,f2)) return pairs - + def show_list_of_file_pair (conll_tg_pairs, err_cnt = None, enc_dict = None, reverse = False): - if reverse: + if reverse: conll_tg_pairs = conll_tg_pairs[::-1] for n,p in enumerate(conll_tg_pairs): conll, tg = p string_to_display = u'{}:\t{:5s}: {}\n\t{:5s}: {}'.format(n,'CoNLL-U',conll,'TextGrid',tg) # encoding - if enc_dict : + if enc_dict : if tg in enc_dict.keys(): enc = enc_dict[tg] - if enc: + if enc: string_to_display += ' [{}]\n'.format(enc) # error count if err_cnt: if conll in err_cnt.keys(): - num_err = err_cnt[conll] - if num_err: + num_err = err_cnt[conll] + if num_err: string_to_display += '\n\tnumber of errors: {}\n'.format(num_err) string_to_display += '\n' print(string_to_display) +def core_routine(conll,srcCol,pauseSign,dest,ref, num_sent_to_read = -1): + # initialization + tokens = [] + sentId = 0 + pauseId = 0 + cursor = 0 + err_num = 0 + dist_tot = 0 + + # boucle de lecture + for n, row in enumerate(conll) : + + # les métadonnées + if row and len(row) < 10 : + metadata = row[0] + #deb_print("L{} mdata[:50] '{}'".format(n,metadata[:50])) + continue + + # token dans une phrase + if row : + token = row[srcCol - 1] + # récolte des tokens + if token.strip() != pauseSign : + tokens.append(token) + #deb_print("L{} tokens[-5:-1] '{}'".format(n,tokens[-5:-1])) + else : + #deb_print("L{} pause no.{} '{}'".format(n,pauseId, token)) + pauseId += 1 + + # saute de ligne à la frontière des phrases + else : + sent = ' '.join(tokens) + deb_print("L{} sentence no.{} '{}'".format(n,sentId,sent)) + + # try a local search from cursor to end of time with by default thld. + [begin,end, cursor_out, best_dist] = findTimes(tokens,ref, lowerbound=cursor, upperbound=cursor+50,thld = 0.10) + if cursor_out >= cursor : + cursor = cursor_out + deb_print("L{} local (begin,end) = ({:8.3f},{:8.3f})".format(n,begin,end)) + + # écrire le contenu dans le tier de destination + try: + dest.add_interval(begin=begin, end=end, value=sent, check=True) + except Exception as e: + print("Error @ L{} of the CoNLL : {}".format(n,e)) + err_num+=1 + + else: + # try a global search but with a more strict threshold for distance + [begin,end, cursor_out, best_dist] = findTimes(tokens,ref, lowerbound=0, upperbound=-1, thld = 0.05) + if cursor_out >= 0 : + #deb_print("L{} global (begin,end) = ({:8.3f},{:8.3f})".format(n,begin,end)) + + # écrire le contenu dans le tier de destination + try: + dest.add_interval(begin=begin, end=end, value=sent, check=True) + deb_print("L{} global (begin,end) = ({:8.3f},{:8.3f})".format(n,begin,end)) + except Exception as e: + print("Error @ L{} of the CoNLL : {}".format(n,e)) + err_num+=1 + else: + print("Search Fails @ L{} of the CoNLL".format(n)) + err_num+=1 + + # early break if number of sentences to read is reached + if sentId > num_sent_to_read and num_sent_to_read > 0 : break + # early exit() if too much dense errors + if err_num > 30 and err_num > 0.6 * sentId: + print('Warning: early exit of core routine with {} as ref. tier'.format(ref.name)) + return None + + # préparation à la prochaine phrase + dist_tot += best_dist; + sentId += 1 + tokens = [] + + return err_num,dist_tot + if __name__ == '__main__': - + # filelists / tiernames / constants # creattion of a frendly command-line interface using argparse @@ -406,24 +485,21 @@ def show_list_of_file_pair (conll_tg_pairs, err_cnt = None, enc_dict = None, rev #print("\tFichier TG traité : {}\n".format(inTgfile)) conll_path = args.conll_in+'/'+inconllFile inTg_path = args.praat_in+'/'+inTgfile - - #lecture du fichier tabulaire (CoNLL-U) - conll = csv.reader(open(conll_path, 'r'), delimiter='\t', quotechar='\\') - + print('\t{:s} {:s}'.format('<-',conll_path)) # detection of textgrid file encoding:utf-8, ascii, etc. enc[inTgfile] = get_encoding(inTg_path) print('\t{:s} {:s} [{}]'.format('<-',inTg_path, enc[inTgfile] if enc[inTgfile] else 'unknown')) outputTg_path = args.praat_out+'/'+insert_to_basename(inTgfile,'_UPDATED','TextGrid') print('\t{:s} {:s} [{}]'.format('->',outputTg_path, 'binary')) - + try: tg = TextGridPlus(file_path=inTg_path, codec=enc[inTgfile]) #lecture du fichier textgrid (Praat) - + except Exception as e: print('Error: {}'.format(e)) continue - + # handel diff. reference tier names ref = None for refTierName in refTierNames : @@ -432,76 +508,46 @@ def show_list_of_file_pair (conll_tg_pairs, err_cnt = None, enc_dict = None, rev break except IndexError: pass - if not ref: print('Error: cannot find a good reference tier for time alignement!');continue - - dest = tg.add_tier(destTierName) #tier de destination ('tx') - - # initialization - tokens = [] - sentId = 0 - pauseId = 0 - cursor = 0 - - # boucle de lecture - for n, row in enumerate(conll) : - - # les métadonnées - if row and len(row) < 10 : - metadata = row[0] - #deb_print("L{} mdata[:50] '{}'".format(n,metadata[:50])) - continue - - # token dans une phrase - if row : - token = row[srcCol - 1] - # récolte des tokens - if token.strip() != pauseSign : - tokens.append(token) - #deb_print("L{} tokens[-5:-1] '{}'".format(n,tokens[-5:-1])) - else : - #deb_print("L{} pause no.{} '{}'".format(n,pauseId, token)) - pauseId += 1 - - # saute de ligne à la frontière des phrases - else : - sent = ' '.join(tokens) - deb_print("L{} sentence no.{} '{}'".format(n,sentId,sent)) - - # try a local search from cursor to end of time with by default thld. - [begin,end, cursor_out] = findTimes(tokens,ref, lowerbound=cursor, upperbound=cursor+50,thld = 0.10) - if cursor_out >= cursor : - cursor = cursor_out - deb_print("L{} local (begin,end) = ({:8.3f},{:8.3f})".format(n,begin,end)) - - # écrire le contenu dans le tier de destination - try: - dest.add_interval(begin=begin, end=end, value=sent, check=True) - except Exception as e: - print("Error @ L{} of the CoNLL : {}".format(n,e)) - err[inconllFile]+=1 - - else: - # try a global search but with a more strict threshold for distance - [begin,end, cursor_out] = findTimes(tokens,ref, lowerbound=0, upperbound=-1, thld = 0.05) - if cursor_out >= 0 : - #deb_print("L{} global (begin,end) = ({:8.3f},{:8.3f})".format(n,begin,end)) - - # écrire le contenu dans le tier de destination - try: - dest.add_interval(begin=begin, end=end, value=sent, check=True) - deb_print("L{} global (begin,end) = ({:8.3f},{:8.3f})".format(n,begin,end)) - except Exception as e: - print("Error @ L{} of the CoNLL : {}".format(n,e)) - err[inconllFile]+=1 - else: - print("Search Fails @ L{} of the CoNLL".format(n)) - err[inconllFile]+=1 - - # préparation à la prochaine phrase - sentId += 1 - tokens = [] - - #path = "./{}/{}".format(out_rep, outputTgFile) + if ref: + dest = tg.add_tier(destTierName) #tier de destination ('tx_new') + err_num,best_dist=core_routine(conll,srcCol,pauseSign,dest,ref) + else: + + err_nums = collections.Counter() + dists = collections.Counter() + + all_tier_names = [tier.name for tier in tg.get_tiers()] + for tierName in all_tier_names : + tg.remove_tier(destTierName) + dest = tg.add_tier(destTierName) #tier de destination ('tx_new') + ref = tg.get_tier(tierName) + #lecture du fichier tabulaire (CoNLL-U) + with open(conll_path, 'r') as f: + conllReader = csv.reader(f, delimiter='\t', quotechar='\\') + # test for 10 sentences CoNLL to get an idea about the error rate of each tier + # as time ref. tier + ret = core_routine(conllReader,srcCol,pauseSign,dest,ref, num_sent_to_read=10) + if ret: err_nums[ref.name],dists[ref.name] = ret + else: continue + + if not err_nums or not dists: + print('Error: cannot find a good reference tier !') + continue + + # remove the output tier created during test + tg.remove_tier(destTierName) + dest = tg.add_tier(destTierName) #tier de destination ('tx_new') + # get the best matched tier for time ref. use with least accumulated edit distance + best_ref_name, best_dist = dists.most_common()[-1] + #print(dists) # debug + # final output based on the rigth ref. tier + print('Info: detect \'{}\' as time reference tier'.format(best_ref_name)) + ref = tg.get_tier(best_ref_name) #tier de repères temporels ('mot') + with open(conll_path, 'r') as f: + conllReader = csv.reader(f, delimiter='\t', quotechar='\\') + err_num,best_dist=core_routine(conllReader,srcCol,pauseSign,dest,ref) + + err[inconllFile]=err_num tg.to_file(outputTg_path, mode='binary', codec='utf-8') print("\n\nDONE.\n\n")