#!/usr/bin/env python # Copyright (c) 2014, U Chun Lao All rights reserved. # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # 1. Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# Perform a greedy search on lattice paths: for each input lattice, pick at
# every node the outgoing edge whose transition a logistic-regression model
# scores highest, and emit the resulting single path in CoNLL-like format.

import sys
import string
import os
import re
import heapq
import random
import pickle
import numpy
import argparse
from sklearn import linear_model, feature_extraction

parser = argparse.ArgumentParser()
parser.add_argument('-inputF', '-input', required=True, help='input lattice file')
parser.add_argument('-outputF', '-output', default=False, help='output file')
parser.add_argument('-branch', '-b', default=100, type=int, help='number of paths')
parser.add_argument('-logReg', required=True, help='log regression model file')
parser.add_argument('-enco', '-e', required=True, help='log regression model encoding')
parser.add_argument('-vect', '-v', required=True, help='log regression model vectorizer')
# BUG FIX: '-t' previously lacked type=int, so a command-line value stayed a str.
parser.add_argument('-t', default=30, type=int,
                    help='max number of failures of generating new paths before giving up')
parser.add_argument('-logErr', '-logtostderr', default=False, action='store_true',
                    help='log to stderr')
parser.add_argument('-debugF', '-debug', default=False, action='store_true',
                    help='debug flag')
parser.add_argument('-bias', default=0, type=int, help='output grouping file index bias')
parser.add_argument('-grouping', default=False, help='grouping file')
args = parser.parse_args(sys.argv[1:])

inputF = args.inputF
# BUG FIX: was `output = args.outputF`; every later use reads `outputF`,
# which would have raised NameError as soon as output was opened.
outputF = args.outputF
branch = args.branch
logReg = args.logReg
enco = args.enco
# BUG FIX: argparse stores '-vect'/'-v' under dest 'vect'; `args.vectF`
# raised AttributeError on every run.
vectF = args.vect
grouping = args.grouping
tot = args.bias      # running sentence index, offset by -bias, for the grouping file
maxTrial = args.t    # currently unused; kept for the beam-search variant
logErr = args.logErr
debugF = args.debugF

# heap buffer size (unused here; presumably for the beamN variant — TODO confirm)
buffSize = 100
# heap buffer cutoff size (unused here)
cutoff = 100
# greedy flag: True selects greedy(); NOTE(review): the False branch calls
# beamN(), which is not defined in this file and would raise NameError.
greedyF = True

entries = []   # list of sentences; each sentence is a list of lattice-edge lines
buff = []      # lines of the entry currently being accumulated
lnum = 0       # physical line counter
endPt = []     # line number at which each entry ended (used for -debug output)

if inputF:
    istream = open(inputF, 'r')
    if logErr:
        sys.stderr.write('reading from %s\n' % (inputF))
else:
    istream = sys.stdin

# Entries are separated by (near-)blank lines.
for l in istream:
    lnum += 1
    if len(l.strip()) > 1:
        buff.append(l.replace('\n', ''))
    else:
        if len(buff) > 0:
            endPt.append(lnum)
            entries.append(buff)
            buff = []
# flush a trailing entry with no terminating blank line
if len(buff) > 0:
    endPt.append(lnum)
    entries.append(buff)
if inputF:
    istream.close()

if logErr:
    sys.stderr.write('Read %d lines, %d sentences\n' % (lnum, len(entries)))
    sys.stderr.write(('Selecting %d paths\n' % (branch)) if branch else
                     'Selecting all possible paths\n')

# read in model (a pickled sklearn classifier exposing predict_log_proba)
if logReg:
    with open(logReg, 'rb') as lrF:
        model = pickle.load(lrF)
    if logErr:
        sys.stderr.write('Read in logistic regression model\n')

# read in encoding file: "feature-string integer-id" per line
encoding = {}
if enco:
    for l in open(enco, 'r'):
        spt = l.split(' ')
        encoding[spt[0]] = int(spt[1])
    if logErr:
        sys.stderr.write('Read in encoding file\n')

# read in vectorizer (pickled; presumably a sklearn DictVectorizer — TODO confirm)
vect = None
if vectF:
    with open(vectF, 'rb') as vF:
        vect = pickle.load(vF)
    if logErr:
        sys.stderr.write('Read in vectorizer file\n')


def encode(feat, encoding):
    """Encode a list of feature strings to integer ids.

    Returns (encoded_list, encoding) when new entries were added to the
    encoding dictionary, (encoded_list, False) otherwise.  Unknown features
    are assigned the next free id in-place.
    """
    rtn = []
    addNew = False
    for f in feat:
        if f in encoding:
            rtn.append(encoding[f])
        else:
            rtn.append(len(encoding))
            encoding[f] = len(encoding)
            addNew = True
    if addNew:
        return rtn, encoding
    else:
        return rtn, False


def encode2(inFeat, encoding):
    """Separate the per-token features of a (prev, next) edge pair, encode
    them, and add cross-product features between the tokens.

    Returns (feature_dict, encoding) when the encoding grew, else
    (feature_dict, False).
    """
    feat = {}
    extFeat = []
    extList = [0, 3]  # columns kept for cross features: word and POS
    crossFeat = {}
    addNew = False
    for i in range(len(inFeat)):
        sptFeat = inFeat[i].split()
        # separate morphological features of the form k=v|k=v|...
        morFeat = []
        if '=' in sptFeat[4]:
            mor = sptFeat[4].split('|')
            for m in mor:
                morFeat.append(m.split('='))
        spt, newEnc = encode(sptFeat[:4] + [m[1] for m in morFeat], encoding)
        if newEnc:
            encoding = newEnc
            addNew = True
        # store extracted features for later cross-product use
        extFeat.append([sptFeat[j] for j in extList])
        feat['word-%d' % i] = spt[0]
        feat['lemma-%d' % i] = spt[1]
        feat['cpos-%d' % i] = spt[2]
        feat['pos-%d' % i] = spt[3]
        # process morphological features
        for j in range(len(morFeat)):
            feat['mor-%s-%d' % (morFeat[j][0], i)] = spt[4 + j]
    # add combined (cross-token) features
    for i in range(len(inFeat)):
        for j in range(len(inFeat)):
            if i == j:
                continue
            for k1 in range(len(extFeat[i])):
                k2 = k1
                while k2 < len(extFeat[j]):
                    crossFeat['cross-%d,%d|%d,%d' % (i, k1, j, k2)] = \
                        '%s|%s' % (extFeat[i][k1], extFeat[j][k2])
                    k2 += 1
    # BUG FIX (py3): dict views are not indexable; materialize the key order
    # once so keys[i] below stays aligned with the encoded values.
    keys = list(crossFeat.keys())
    spt, newEnc = encode([crossFeat[c] for c in keys], encoding)
    if newEnc:
        encoding = newEnc
        addNew = True
    for i in range(len(crossFeat)):
        feat[keys[i]] = spt[i]
    return feat, (encoding if addNew else False)
# end of encode


def greedy(sent, mode, vect, encoding, prog=False):
    """Greedily walk the lattice, keeping at each node the outgoing edge whose
    (previous-token, edge) feature pair the model scores highest.

    sent: list of tab-separated edge lines "head\ttail\t<features...>".
    prog: (sentence_index, total_sentences) for progress logging, or False.
    Returns ([path], encoding) when the encoding grew, else ([path], False).
    """
    nodes = {}     # head node -> [(tail node, feature string), ...]
    probs = {}     # memoized log-probabilities per encoded feature tuple
    endNode = 0
    addNew = False
    cyc = []
    # form lattice nodes look-up table
    for n in sent:
        spt = n.split('\t')
        head = int(spt[0])
        tail = int(spt[1])
        feat = '\t'.join(spt[2:-1])
        if tail > endNode:
            endNode = tail
        # if it is a cycle (self-loop), skip to the next edge
        if head == tail:
            cyc.append(head)
            continue
        if head in nodes:
            nodes[head].append((tail, feat))
        else:
            nodes[head] = [(tail, feat)]
    if len(cyc) > 0:
        # NOTE(review): this report indexes prog[0], so cycles with the
        # default prog=False would raise TypeError — callers always pass prog.
        cycstr = '%d' % cyc[0]
        if len(cyc) > 1:
            for c in cyc[1:]:
                cycstr += ', %d' % c
        sys.stderr.write('%4d: %2d cycle at %s'.ljust(70) % (prog[0], len(cyc), cycstr))
        sys.stderr.write('\n')
    path = ['-ROOT-\t_\t_\t_\t_']
    # expand path node by node
    for current in range(endNode):
        if prog:
            sys.stderr.write('overall: %4d / %4d, local: %3d / %3d \r'
                             % (prog[0], prog[1], current, endNode))
            sys.stderr.flush()
        best = None
        for n in nodes[current]:
            feat = (path[-1], n[1])
            encFeat, newEnc = encode2(feat, encoding)
            if newEnc:
                encoding = newEnc
                addNew = True
            feaTup = tuple(encFeat)
            if feaTup in probs:
                p = probs[feaTup]
            else:
                p = model.predict_log_proba(vect.transform(encFeat))[0][1]
                probs[feaTup] = p
            if best is None or p > best[0]:
                best = (p, n[1])
        path.append(best[1])
    return [path], (encoding if addNew else False)


# BUG FIX (py3): print statement -> print function.
print(model.classes_)

if outputF:
    ostream = open(outputF, 'w')
else:
    ostream = sys.stdout
if logErr:
    sys.stderr.write('Start writing to %s\n' % (outputF if outputF else 'stdout'))

total = 0
sentCt = 0        # unused; retained from the original
splitteds = []    # number of output sentences produced per input entry
# BUG FIX (py3): '/' is float division in py3; checkpoint must stay integral.
chkPt = len(entries) // 20
if chkPt < 1:
    chkPt = 1
for e in entries:
    if debugF:
        sys.stderr.write('%d at line %d\n' % (total, endPt[total]))
    total += 1
    if greedyF:
        fe, newEnc = greedy(e, model, vect, encoding,
                            (len(splitteds), len(entries)))
    else:
        # NOTE(review): beamN is not defined in this file; this branch is
        # dead while greedyF is hard-coded True, but would raise NameError.
        fe, newEnc = beamN(e, model, vect, encoding, branch,
                           (len(splitteds), len(entries)))
    if newEnc:
        encoding = newEnc
    splitted = 0
    for s in fe:
        splitted += 1
        tid = 0
        for t in s[1:]:
            tid += 1
            ostream.write('%d\t%s\t_\t_\t_\t_\n' % (tid, t))
        if debugF:
            # NOTE(review): greedy() paths start with a string token, so
            # '%f' % s[0] would TypeError in debug mode; format expected a
            # (score, ...) path from beamN — TODO confirm before relying on it.
            ostream.write('%f\n' % s[0])
        ostream.write('\n')
    splitteds.append(splitted)
    if logErr:  # and len(splitteds) % chkPt == 0:
        sys.stderr.write('progress: %4d / %4d \r' % (len(splitteds), len(entries)))
        sys.stderr.flush()
if outputF:
    ostream.close()

if grouping:
    with open(grouping, 'w') as f:
        for s in splitteds:
            f.write('%d\t%d\t%d\n' % (tot, tot + s - 1, s))
            tot += s
    if logErr:
        sys.stderr.write('\nWrites grouping information to %s' % grouping)

if logErr:
    sys.stderr.write('\nprocessed %d sentences into ' % total)
    for s in splitteds:
        sys.stderr.write('%d ' % s)
    sys.stderr.write('sentences, %d in total\n' % (sum(splitteds)))