https://github.com/recski/HunTag
Raw File
Tip revision: ac681ac97f8677f25c29060fc8a168d68aca22d9 authored by Gábor Recski on 18 January 2016, 08:52:12 UTC
Update README.md
Tip revision: ac681ac
eval.py
#!/usr/bin/python
#eval.py comes with HunTag and is capable of evaluating precision and recall of sequential tagging output. It can also compute various other statistics. When run with the -c option, its operation is identical to the perl script conlleval.pl provided for the ConLL shared tasks chunking and NER-tagging.
import sys
import optfunc
def getChunksFromCorp(taggedFile, goldField, autoField, mode, strict):
    
    goldChunks = getChunksFromColumn(taggedFile, goldField, mode, strict)
    autoChunks = getChunksFromColumn(taggedFile, autoField, mode, strict)
    
    return (goldChunks, autoChunks)

def printError(lineNum):
    print 'illicit tag in line '+str(lineNum)
    
def getChunksFromColumn(taggedFile, field, mode, strict):    
    chunks = []
    chunkStart = None
    chunkType = None
    
    for c,line in enumerate(taggedFile):
        if line=='\n':
            continue
	try:
            tag = line.split()[field]
        except IndexError:
            print
            sys.stderr.write('invalid number of fields in line '+str(c)+'\n'+line)
            sys.exit(-1)
        
        if tag[0] == 'B':
            if strict:
                if mode == 'BIE1':
                    if chunkStart!=None:
                        printError(c)
                        sys.exit(-1)
            
            
            if chunkStart != None:
                chunks.append( (chunkStart, c-1, chunkType) )
                
            chunkStart = c
            chunkType = tag[2:]
            
        elif tag[0]=='I':
            if strict:    
                if chunkStart==None or chunkType!=tag[2:]:
                    printError(c)
                    sys.exit(-1)
        
        elif tag[0]=='E':
            if strict:
                if mode !='BIE1' or chunkStart==None or chunkType!=tag[2:]:
                    printError(c)
                    sys.exit(-1)
            
            if chunkStart==None:
                chunkStart=c
                chunkType=tag[2:]
                
            chunks.append( (chunkStart, c, chunkType) )
            chunkStart=None
            chunkType=None
        
        elif tag[0]=='1':
            if strict:
                if mode!='BIE1':
                    printError(c)
                    sys.exit(-1)
            
            if chunkStart!=None:
                chunks.append( (chunkStart, c-1, chunkType) )
            
            chunks.append( (c, c, tag[2:]) )
            
            chunkStart=None
            chunkType=None
            
        else:
            if strict:
                if tag[0]!='O':
                    printError(c)
                    sys.exit(-1)
                
            if chunkStart==None:
                pass
            
            else:
                if strict:
                    if mode=='BIE1':
                        printError(c)
                        sys.exit(-1)
                
                chunks.append( (chunkStart, c-1, chunkType) )
                chunkStart=None
                chunkType=None
        
        if c == len(taggedFile)-1 and chunkStart!=None:
            chunks.append( (chunkStart, c, chunkType) )
                
    return set(chunks)


def evaluate(chunks):
    type_corr={}
    type_auto={}
    type_gold={}
    
    aCh = chunks[1]
    gCh = chunks[0]
    
    n_goldChunks = len(gCh)
    n_autoChunks = len(aCh)
    
    corr = 0.0
    
    for chunk in aCh:	
        type=chunk[2]
        #if type != 'NP': print chunk
        try:
            type_auto[type]+=1
        except KeyError:
            type_auto[type]=1.0
            
        if chunk in gCh:# or (chunk[0], chunk[1], 'MISC') in gCh:
            corr += 1
            try:
                type_corr[type]+=1
            except KeyError:
                type_corr[type]=1.0
    
        
    
    
    for chunk in gCh:
        type = chunk[2]
        try:
            type_gold[type]+=1
        except KeyError:
            type_gold[type]=1.0
    
    types_gold=[type for type in type_gold]
    types_auto=[type for type in type_auto]
    types = types_gold+types_auto
    
    return (n_goldChunks, n_autoChunks, corr, set(types), type_gold, type_auto, type_corr)

def count(chunkCounts, tokCounts):
    results={}
    n_goldChunks, n_autoChunks, corr, types, type_gold, type_auto, type_corr = chunkCounts
    
    toks, corrToks = tokCounts
    
    results['tokens']=toks
    results['acc']=round(100*(corrToks/toks), 2)
    
    try:
        precision = corr/n_autoChunks
    except ZeroDivisionError:
        precision = 0
    
    try:
        recall = corr/n_goldChunks
    except ZeroDivisionError:
        recall = 0
    
    try:    
        fScore = (2*precision*recall)/(precision+recall)
    except ZeroDivisionError:
        fScore = 0
    
    results['goldPhrases']=n_goldChunks
    results['foundPhrases']=n_autoChunks
    results['corrPhrases']=corr
    results['All']=(100*precision, 100*recall, 100*fScore)
    
    for type in types:
        if type_auto.has_key(type):
        
            c=type_auto[type]
            if type in type_corr:
                prec = 100*(type_corr[type]/c)
                rec = 100*(type_corr[type]/type_gold[type])
                fSc = (2*prec*rec)/(prec+rec)
            
            
                results[type]=(prec,rec,fSc,c)
        
            else:
                results[type]=(0,0,0,c)
            
        else:
            results[type]=(0,0,0,0)
        
    return results
    

def printResults(results):
    tok = results['tokens']
    phras = results['goldPhrases']
    found = results['foundPhrases']
    corr = results['corrPhrases']
    acc = results['acc']
    prec = results['All'][0]
    rec = results['All'][1]
    fb = results['All'][2]
    
    print 'processed %.0f tokens with %.0f phrases;' % (tok, phras),
    print 'found: %.0f phrases; correct: %.0f.\n' % (found, corr),
    
    print 'accuracy: %6.2f%%; precision: %6.2f%%;' % (acc, prec),
    print 'recall: %6.2f%%; FB1: %6.2f' % (rec, fb)
    
    sortedTypes=results.keys()
    sortedTypes.sort()
    
    for type in sortedTypes:
        if type in ('corrPhrases', 'tokens', 'foundPhrases', 'acc', 'All', 'goldPhrases'):
            continue
        
        prec = results[type][0]
        rec = results[type][1]
        fb = results[type][2]
        found = int(results[type][3])
        
        print '%17s:' % (type),
        print 'precision: %6.2f%%;' % prec,
        print 'recall: %6.2f%%;' % rec,
        print 'FB1: %6.2f    %.0f' % (fb, found)

def analyzeErrors(chunks):
    gCh=chunks[0]
    aCh=chunks[1]

    errorTypes={'wrongType':0, 'disjunct':0, 'overlap':0, 'aInG':0, 'gInA':0}
    wrongCategory={}
    for chunk in aCh:
        if chunk in gCh:
            continue
        
        nearest=getNearestChunk(chunk, gCh)
        
        if chunk[2]!=nearest[2]:
            errorTypes['wrongType']+=1
            
            try:
                wrongCategory[(nearest[2], chunk[2])]+=1
            except KeyError:
                wrongCategory[(nearest[2], chunk[2])]=1
            
            continue
                    
        gPos=(nearest[0], nearest[1])
        aPos=(chunk[0], chunk[1])
        
        rel = compareChunks(gPos, aPos)
        errorTypes[rel]+=1        

    allErrors=0
    for type in errorTypes:
        allErrors+=errorTypes[type]
    
    return (errorTypes, allErrors, wrongCategory)

def compareChunks(gPos, aPos):
    if aPos[0] < gPos[0]:
        if aPos[1] < gPos[0]:
            return 'disjunct'
        elif aPos[1] < gPos[1]:
            return 'overlap'
        else:
            return 'gInA'
        
    elif aPos[0]==gPos[0]:
        if aPos[1]>gPos[1]:
            return 'gInA'
        else:
            return 'aInG' 
    
    else:
        if gPos[1] < aPos[0]:
             return 'disjunct'
        elif gPos[1] < aPos[1]:
             return 'overlap'
        else:
             return 'aInG'


def getNearestChunk(chunk, gCh):
    chSum=chunk[0]+chunk[1]
    nearest=None
    bestDist=999
    for currChunk in gCh:
        currSum=currChunk[0]+currChunk[1]
        if abs(currSum-chSum)<bestDist:
            bestDist=abs(currSum-chSum)
            nearest=currChunk

    return nearest

def printErrorTypes(errorTypes):
    for type in errorTypes:
        print type+':\t'+str(errorTypes[type])
    

def getChunkPatterns(taggedFile, goldField, autoField, mode):
    if mode != 'BIE1':
        sys.stderr.write('extracting chunk patterns is currently not available in \'BI\' mode!')
        sys.exit(-1)
    
    
    patternCount={}
    
    prevTags = ('O', 'O')
    pattern=[]
    curr=''
    G=False
    A=False
    for c,l in enumerate(taggedFile):
        if l == '\n':
            if len(pattern)!=0:
                wPatt = ' '.join(pattern)
                #print wPatt
                try:
                    patternCount[wPatt]+=1
                except KeyError:
                    patternCount[wPatt]=1
                
                pattern=[]
            
            continue
            
        line = l.split()
        goldTag = line[goldField][0]
        autoTag = line[autoField][0]

        
    
        if goldTag in ('B', 'E'):
            curr+='g'
            G=False
        elif goldTag == '1':
            curr+='gg'
            G=False
        elif goldTag == 'O':
            if not G:
                curr+='G'
                G = True
        
        else:
            assert goldTag == 'I'
                
        if autoTag in ('B', 'E'):
            curr+='a'
            A=False
        elif autoTag == '1':
            curr+='aa'
            A=False
        elif autoTag == 'O':
            if not A:
                curr+='A'
                A = True
            
        else:
            assert autoTag == 'I' 
            
                        
        if curr not in ('', 'GA'):
            pattern.append(curr)
        
        curr=''
        
        if goldTag==autoTag:
            
            if autoTag in ('I', 'B'):
                pass
                
            else:
                if autoTag == 'O' and pattern==[]:
                    continue
                    
                wPatt = ' '.join(pattern)
                #print wPatt
                try:
                    patternCount[wPatt]+=1
                except KeyError:
                    patternCount[wPatt]=1
                
                pattern=[] 
    
    return patternCount    


def printPatterns(patternCount):
    
    #sys.stderr.write('1')
    
    all = 0.0
    patterns=[]
    
    """
    patternsToAscii = {'ga ga': '---_____---\n---_____---',
                                         'ggaa' : '---_---\n---_---',
                                         'ga g g ga': '---_____---_____---\n---_____________---',
                                         'ga aa ga': '---_______---\n---___|___---',
                                         'ga g g g g ga': '---_____---_____---_____---\n---_____________________---',
                                         'ga a a ga': '---_____________---\n--_____---_____---',
                                         'a g ga':'-------_____---\n---_________---',
                                         'g a ga':'---_________---\n-------_____---'}
    """                                     
    
    for patt in patternCount:
        if patt in ('ga ga', 'ggaa', 'GA'):
            continue
        all+= patternCount[patt]
        patterns.append( (patternCount[patt], patt) )
        
    patterns.sort()
    patterns.reverse()
    for patt in patterns:
        percent = (patternCount[patt[1]]/all)*100
        print str(percent)[:6]+'%\t'+patt[1]+'\n'+patternsToAscii(patt[1])

def patternsToAscii(patt):
    #print patt
    asciiPattG=''
    asciiPattA=''
    for pos in patt.split():
        if 'gg' in pos: 
            asciiPattG+='1'
        elif 'g' in pos:
            asciiPattG+='|'
        
        elif 'G' in pos:
            asciiPattG+=' '
        
        else:    
            asciiPattG+='-'    
            
        if 'aa' in pos: 
            asciiPattA+='1'
        elif 'a' in pos:
            asciiPattA+='|'
        
        elif 'A' in pos:
            asciiPattA+=' '
        
        else:    
            asciiPattA+='-'
        
        
    return asciiPattG+'\n'+asciiPattA    


def countToks(taggedFile, goldField, autoField):
    toks=0.0
    corrToks=0.0
    #print taggedFile
    for line in taggedFile:
        #print '@'
        if line == '\n': continue
        
        toks+=1
        if line.split()[goldField]==line.split()[autoField]:
            corrToks+=1
    
    
    return (toks, corrToks)
    
def printConfMatrix(wrongCategory):
    pairs=wrongCategory.keys()
    pairs.sort()
    for pair in pairs:
        print '%s --> %s %.0f' % (pair[0], pair[1], wrongCategory[pair]) 


def leaveInternalBs(corp):
    newS=True
    newCorp=[]
    for line in corp:
        if line == '\n':
            newCorp.append('\n')
            newS=True
            continue
        else:
            l=line.split()
            newL=l[:-2]
            if l[-2][0]=='B' and not newS:
                newL+=['B'+l[-2][1:]]
            else:
                newL+=['O']
            
            if l[-1][0]=='B' and not newS:
                newL+=['B'+l[-1][1:]]
            else:
                newL+=['O']

            assert len(newL) == len (l)
            newCorp.append('\t'.join(newL))                     
            newS=False
    return newCorp

def getSenPrec(corp):
    allSen=0.0
    corrSen=0.0
    thisSen=True
    for line in corp:
        if line == '\n':
            allSen+=1
            if thisSen:
                corrSen+=1
            thisSen=True
            
        else:
            if line.split()[-2]!=line.split()[-1]:
                thisSen=False
    print 'sentence precision:', str((corrSen/allSen)*100)[:5]+'%'

def runEval(stdin, goldField='-2', autoField='-1', mode='BI', conll=False, bPoints=False, sen=False, strict=False, pattern=False):
    
    """
    try:
        taggedFile = sys.stdin.readlines()
        goldField = int(sys.argv[1])
        autoField = int(sys.argv[2])
        mode = sys.argv[3]
        assert mode in ('BI', 'BIE1')
    except:
        sys.stderr.write( 'usage: eval.py <gold field> <auto field> <mode>\n' )
        sys.exit(-1) 
    """
    
    tF=stdin.readlines()
    chunks = getChunksFromCorp(tF, int(goldField), int(autoField), mode, strict)
    
    chunkCounts = evaluate(chunks)
    #print tF
    tokCounts = countToks(tF, int(goldField), int(autoField))
    #print tokCounts
    results = count(chunkCounts, tokCounts)
    
    #if conll==True:
    if conll==True:
        printResults(results)
        
    if bPoints:
        newCorp=leaveInternalBs(tF)
        #print '\n'.join(newCorp)
        newChunks=getChunksFromCorp(newCorp, int(goldField), int(autoField), 'BI', strict)
        newChunkCounts=evaluate(newChunks)
        newTokCounts=countToks(newCorp, int(goldField), int(autoField))
        newResults = count (newChunkCounts, newTokCounts)
        printResults(newResults)
    #errorTypes, allErrors, wrongCategory=analyzeErrors(chunks)
    #print str(wrongCategory)
    
    if sen:
        getSenPrec(tF)
    
    #assert allErrors+results[2]==results[1]
    #printErrorTypes(errorTypes)
    #printConfMatrix(wrongCategory)
    if pattern:
        patternCount = getChunkPatterns(tF, int(goldField), int(autoField), mode)
    
        printPatterns(patternCount)

def evalInput(input, autoField=-1, goldField=-2):
    corp = []
    for sen in input:
        for tok in sen:
            corp.append('\t'.join(tok)+'\n')
        corp.append('\n')
    
    chunks = getChunksFromCorp(corp, goldField, autoField, 'BI', False)
    chunkCounts = evaluate(chunks)
    tokCounts = countToks(corp, goldField, autoField)
    results = count(chunkCounts, tokCounts)
    printResults(results)

if __name__ == '__main__':
    #runEval(sys.stdin, '-2', '-1', 'BIE1', False, True, True, False, False)
    optfunc.run(runEval)
back to top