python下的ahocorasick实现快速的关键字匹配

这两天在折腾数据的分析及导出。爬虫抓取页面的时候,我们会做关键字的匹配,并在数据库中标记这个url是否含有我们需要的关键字。这个时候就不能再用find()逐个关键字去查了,效率太低,而且在同时处理几千个任务的时候,会出现cpu的瓶颈。如果采用ahocorasick来实现,可以很有效地减轻cpu的消耗。

AC自动机是多模式匹配的一个经典数据结构,和KMP一样要构造fail指针,只不过AC自动机的fail指针是在Trie树上构造的,原理是一样的。我这里就不多说ac的原理了,有兴趣的朋友可以自己找资料瞅瞅。


python下已经有个现成的库了 ,大家直接 pip install ahocorasick .

# Interactive-session transcript with the prompts stripped: a bare tuple on its
# own line is the value echoed for the call directly above it.
import ahocorasick
tree = ahocorasick.KeywordTree()
# Register the keywords; note that "alpha" is a prefix of "alpha beta".
tree.add("alpha")          
tree.add("alpha beta")
tree.add("gamma")

# Called once after the last add() and before any search
# (presumably this is what compiles the automaton -- confirm in the library docs).
tree.make()

# search() reports (10, 15), the span of "alpha" ...
tree.search("I went to alpha beta the other day to pick up some spam")
(10, 15)
# ... while search_long() reports (10, 20), the span of the longer "alpha beta".
tree.search_long("I went to alpha beta the other day to pick up some spam")
(10, 20)
# Matching is on raw substrings: "alpha" is found inside "alphabet".
tree.search("and also got some alphabet soup")
(18, 23)
# No keyword occurs here; the transcript shows no echoed value.
tree.search("but no waffles")

tree.search_long("oh, gamma rays are not tasty")
(4, 9)

# The transcript shows no echoed value for the bare findall() call; the loop
# below iterates its result instead.
tree.findall("I went to alpha beta to pick up alphabet soup")

# Iterating findall() yields one (start, end) span per match.
for match in tree.findall("I went to alpha beta to pick up alphabet soup"):
...     print match
... 
(10, 15)
(32, 37)


如果不想用现成的ac的库,可以自己照着ac的原理写一份,只是过程很痛苦,我还是喜欢拿来主义。毕竟实现算法的过程是个十分xxx的事情。

#coding=utf-8

# Number of child slots allocated in every trie node (one per symbol index).
KIND = 16
# An alphabet offset (BASE = ord('a')) used to exist here; it is disabled.


class Node():
    """One trie node of the automaton.

    Instance attributes:
      fail -- failure link; None until the automaton is built.
      next -- list of KIND child slots, all None on creation.
      end  -- True once a keyword terminates at this node.
      word -- the keyword terminating here, or None.
    """

    # Running count of Node instances created so far (shared by all nodes).
    static = 0

    def __init__(self):
        self.word = None
        self.end = False
        self.next = [None for _ in range(KIND)]
        self.fail = None
        Node.static += 1

class AcAutomation():
    """Aho-Corasick automaton over symbols whose code point is below KIND.

    Usage: insert() every keyword, call build_automation() once, then call
    matchOne() on each text. Until build_automation() has run, the failure
    links of non-root nodes are still None, so matchOne() must not be called
    before it.
    """

    def __init__(self):
        self.root = Node()
        # Kept for backward compatibility with code that inspects this
        # attribute; build_automation() uses its own local work queue.
        self.queue = []

    def getIndex(self,char):
        """Map one symbol to its child-slot index (its code point)."""
        return ord(char)

    def insert(self,string):
        """Add `string` to the trie and mark its last node as a keyword end."""
        p = self.root
        for char in string:
            index = self.getIndex(char)
            if p.next[index] is None:
                p.next[index] = Node()
            p = p.next[index]
        p.end = True
        p.word = string

    def build_automation(self):
        """Compute every node's failure link with a breadth-first traversal."""
        from collections import deque

        root = self.root
        root.fail = None
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        pending = deque([root])
        while pending:
            parent = pending.popleft()
            for i, child in enumerate(parent.next):
                if child is None:
                    continue
                # Default: fall back to the root. Children of the root always
                # keep this default.
                child.fail = root
                if parent is not root:
                    # Walk the parent's failure chain until some node has a
                    # transition on the same symbol.
                    failp = parent.fail
                    while failp is not None:
                        if failp.next[i] is not None:
                            child.fail = failp.next[i]
                            break
                        failp = failp.fail
                pending.append(child)

    def matchOne(self,string):
        """Scan `string` and stop at the first keyword occurrence.

        Returns (True, keyword) for the first position at which any keyword
        ends, preferring the longest keyword ending there, or (False, None)
        when no keyword occurs.
        """
        root = self.root
        p = root
        for char in string:
            index = self.getIndex(char)
            # Follow failure links until a transition on this symbol exists.
            while p.next[index] is None and p is not root:
                p = p.fail
            target = p.next[index]
            p = root if target is None else target
            # A keyword may end here without being the current path itself:
            # it can be a proper suffix of it. Those nodes are reachable only
            # through the failure chain, so check the whole chain.
            hit = p
            while hit is not None:
                if hit.end:
                    return True, hit.word
                hit = hit.fail
        return False, None
    
    


try:
    _text_type = unicode  # Python 2
except NameError:
    _text_type = str  # interpreters without `unicode`: str is already text


class UnicodeAcAutomation():
    """Unicode front end for AcAutomation.

    Text is encoded with `encoding`, and every byte is split into two
    symbols in the range 0..15 (low nibble first, then high nibble) so that
    it fits the 16-slot nodes of the underlying automaton.

    NOTE(review): the underlying automaton scans symbol by symbol and does
    not track byte boundaries, so a keyword's nibble sequence could in
    principle match starting at the high nibble of a byte -- confirm whether
    that matters for the data in use.
    """

    def __init__(self,encoding='utf-8'):
        self.ac = AcAutomation()
        self.encoding = encoding

    def getAcString(self,string):
        """Return `string` encoded and expanded to two nibble symbols per byte."""
        nibbles = []
        for byte in bytearray(string.encode(self.encoding)):
            nibbles.append(chr(byte % 16))
            # Floor division: `/` would produce a float where true division
            # is in effect, and chr() rejects floats.
            nibbles.append(chr(byte // 16))
        return ''.join(nibbles)

    def insert(self,string):
        """Register the unicode keyword `string`.

        Raises TypeError (an Exception subclass) if `string` is not unicode text.
        """
        if not isinstance(string, _text_type):
            raise TypeError('UnicodeAcAutomation:: insert type not unicode')
        self.ac.insert(self.getAcString(string))

    def build_automation(self):
        """Build the failure links of the underlying automaton."""
        self.ac.build_automation()

    def matchOne(self,string):
        """Search `string` for the first keyword occurrence.

        Returns (found, keyword), where keyword is the matched keyword as
        unicode text, or None when nothing matched.
        Raises TypeError (an Exception subclass) if `string` is not unicode text.
        """
        if not isinstance(string, _text_type):
            raise TypeError('UnicodeAcAutomation:: matchOne type not unicode')
        retcode, ret = self.ac.matchOne(self.getAcString(string))
        if ret is not None:
            # Reassemble bytes from (low, high) nibble pairs, then decode with
            # the same encoding that getAcString() used.
            raw = bytearray(
                ord(ret[i]) + ord(ret[i + 1]) * 16
                for i in range(0, len(ret) - 1, 2)
            )
            ret = raw.decode(self.encoding)
        return retcode, ret
    


def main2():
    """Demo: register three keywords, build the automaton, print five lookups."""
    ac = UnicodeAcAutomation()
    for keyword in (u'丁亚光', u'好吃的', u'好玩的'):
        ac.insert(keyword)
    ac.build_automation()

    samples = (
        u'hi,丁亚光在干啥',
        u'ab',
        u'不能吃饭啊',
        u'饭很好吃,有很多好好的吃的,',
        u'有很多好玩的',
    )
    for text in samples:
        # Single-argument form: prints the (found, keyword) tuple as-is.
        print(ac.matchOne(text))


if __name__ == '__main__':
    main2()
    
    


有个网友已经把ac自动机和redis做了一些联合 。

#coding: utf-8

'''
使用redis的ac算法
'''
import redis


def smart_unicode(s, encoding='utf-8'):
    """Return `s` as unicode text.

    A byte string (including subclasses) is decoded with `encoding`; any
    other value is returned unchanged. `bytes` is the same type as `str` on
    Python 2, so plain-`str` input is decoded exactly as before.
    """
    if isinstance(s, bytes):
        return s.decode(encoding)
    return s

def smart_str(s, encoding='utf-8'):
    """Return `s` as an encoded byte string.

    Unicode text (including subclasses) is encoded with `encoding`; any
    other value is returned unchanged.
    """
    try:
        text_type = unicode  # Python 2
    except NameError:
        text_type = str  # interpreters without `unicode`: str is already text
    if isinstance(s, text_type):
        return s.encode(encoding)
    return s

class RedisACKeywords(object):
    '''
    (1) Efficient String Matching: An Aid to Bibliographic Search
    (2) Construction of Aho Corasick Automaton in Linear Time for Integer Alphabets
    '''
    # %s is name
    KEYWORD_KEY=u'%s:keyword'
    PREFIX_KEY=u'%s:prefix'
    SUFFIX_KEY=u'%s:suffix'

    # %s is keyword
    OUTPIUT_KEY=u'%s:output'
    NODE_KEY=u'%s:node'

    def __init__(self, host='localhost', port=6379, db=12, name='RedisACKeywords', encoding='utf8'):
        '''
        db: 7+5 because 1975
        '''
        self.encoding = encoding

        self.client = redis.Redis(host=host, port=port, db=db)
        self.client.ping()

        self.name = smart_unicode(name)

        # Init trie root
        self.client.zadd(self.PREFIX_KEY % self.name, u'', 1.0)


    def add(self, keyword):
        keyword = keyword.strip().lower()
        assert keyword
        keyword = smart_unicode(keyword)

        # Add keyword in keyword set
        self.client.sadd(self.KEYWORD_KEY % self.name, keyword)

        self._build_trie(keyword)

        num = self.client.scard(self.KEYWORD_KEY % self.name)
        return num

    def remove(self, keyword):
        assert keyword
        keyword = keyword.strip().lower()
        keyword = smart_unicode(keyword)

        self._remove(keyword)

        self.client.srem(self.KEYWORD_KEY % self.name, keyword)
        num = self.client.scard(self.KEYWORD_KEY % self.name)
        return num

    def find(self, text):
        ret = []
        i = 0
        state = u''
        utext = smart_unicode(text)
        while i < len(utext):
            c = utext[i]
            next_state = self._go(state, c)
            if next_state is None:
                next_state = self._fail(state + c)
                ####################
                ## the above line likes take same effect as this block
                #next_state = self._fail(state)
                #_next_state = self._go(next_state, c)
                #if _next_state is None:
                #    _next_state = self._fail(next_state + c)
                #next_state = _next_state
                ######################

            outputs = self._output(state)
            ret.extend(outputs)

            state = next_state
            i += 1

        # check last state
        outputs = self._output(state)
        ret.extend(outputs)
        return ret

    def flush(self):
        keywords = self.client.smembers(self.KEYWORD_KEY % self.name)
        for keyword in keywords:
            self.client.delete(self.OUTPIUT_KEY % smart_unicode(keyword))
            self.client.delete(self.NODE_KEY % smart_unicode(keyword))
        self.client.delete(self.PREFIX_KEY % self.name)
        self.client.delete(self.SUFFIX_KEY % self.name)
        self.client.delete(self.KEYWORD_KEY % self.name)

    def info(self):
        return {
            'keywords':self.client.scard(self.KEYWORD_KEY % self.name),
            'nodes':self.client.zcard(self.PREFIX_KEY % self.name),
        }

    def suggest(self, input):
        input = smart_unicode(input)
        ret = []
        rank = self.client.zrank(self.PREFIX_KEY % self.name, input)
        a = self.client.zrange(self.PREFIX_KEY % self.name, rank, rank)
        while a:
            node = smart_unicode(a[0])
            if node.startswith(input) and self.client.sismember(self.KEYWORD_KEY % self.name, node):
                ret.append(node)
            rank += 1
            a = self.client.zrange(self.PREFIX_KEY % self.name, rank, rank)
        return ret

    def _go(self, state, c):
        '''
        转向函数
        '''
        assert type(state) is unicode
        next_state = state + c
        i = self.client.zscore(self.PREFIX_KEY % self.name, next_state)
        if i is None:
            return None
        return next_state

    def _build_trie(self, keyword):
        assert type(keyword) is unicode
        l = len(keyword)
        for i in xrange(l): # trie depth increase
            prefix = keyword[:i+1] # every prefix is a node
            _suffix = u''.join(reversed(prefix))
            if self.client.zscore(self.PREFIX_KEY % self.name, prefix) is None: # node does not exist
                self.client.zadd(self.PREFIX_KEY % self.name, prefix, 1.0)
                self.client.zadd(self.SUFFIX_KEY % self.name, _suffix, 1.0) # reversed suffix node

                self._rebuild_output(_suffix)
            else:
                if (self.client.sismember(self.KEYWORD_KEY % self.name, prefix)): # node may change, rebuild affected nodes
                    self._rebuild_output(_suffix)

    def _rebuild_output(self, _suffix):
        assert type(_suffix) is unicode
        rank = self.client.zrank(self.SUFFIX_KEY % self.name, _suffix)
        a = self.client.zrange(self.SUFFIX_KEY % self.name, rank, rank)
        while a:
            suffix_ = smart_unicode(a[0])
            if suffix_.startswith(_suffix):
                state = u''.join(reversed(suffix_))
                self._build_output(state)
            else:
                break
            rank += 1 # TODO: Binary search?
            a = self.client.zrange(self.SUFFIX_KEY % self.name, rank, rank)

    def _build_output(self, state):
        assert type(state) is unicode
        outputs = []
        if self.client.sismember(self.KEYWORD_KEY % self.name, state):
            outputs.append(state)
        fail_state = self._fail(state)
        fail_output = self._output(fail_state)
        if fail_output:
            outputs.extend(fail_output)
        if outputs:
            self.client.sadd(self.OUTPIUT_KEY % state, *outputs)
            for k in outputs:
                self.client.sadd(self.NODE_KEY % k, state) # ref node for delete keywords in output

    def _fail(self, state):
        '''
        失败函数
        '''
        assert type(state) is unicode
        # max suffix node will be the failed node
        next_state = u''
        for i in xrange(1, len(state)+1): # depth increase
            next_state = state[i:]
            if self.client.zscore(self.PREFIX_KEY % self.name, next_state):
                break
        return next_state

    def _output(self, state):
        '''
        输出函数
        '''
        assert type(state) is unicode
        return [smart_unicode(k) for k in self.client.smembers(self.OUTPIUT_KEY % state)]

    def debug_print(self):
        keywords = self.client.smembers(self.KEYWORD_KEY % self.name)
        if keywords:
            print '-',  self.KEYWORD_KEY % self.name, u' '.join(keywords)
        prefix = self.client.zrange(self.PREFIX_KEY % self.name, 0, -1)
        if prefix:
            prefix[0] = u'.'
            print '-',  self.PREFIX_KEY % self.name, u' '.join(prefix)
        suffix = self.client.zrange(self.SUFFIX_KEY % self.name, 0, -1)
        if suffix:
            print '-',  self.SUFFIX_KEY % self.name, u' '.join(suffix)

        outputs = []
        for node in prefix:
            output = self._output(smart_unicode(node))
            outputs.append({node: output})
        if outputs:
            print '-',  'outputs', outputs

        nodes = []
        for keyword in keywords:
            keyword_nodes = self.client.smembers(self.NODE_KEY % smart_unicode(keyword))
            nodes.append({keyword: keyword_nodes})
        if nodes:
            print '-', 'nodes', nodes

    def _remove(self, keyword):
        assert type(keyword) is unicode
        nodes = self.client.smembers(self.NODE_KEY % keyword)
        for node in nodes:
            self.client.srem(self.OUTPIUT_KEY % smart_unicode(node), keyword)
        self.client.delete(self.NODE_KEY % keyword)

        # remove nodes if need
        l = len(keyword)
        for i in xrange(l, 0, -1): # depth decrease
            prefix = keyword[:i]
            if self.client.sismember(self.KEYWORD_KEY % self.name, prefix) and i!=l:
                break
            _suffix = u''.join(reversed(prefix))

            rank = self.client.zrank(self.PREFIX_KEY % self.name, prefix)
            if rank is None:
                break
            a = self.client.zrange(self.PREFIX_KEY % self.name, rank+1, rank+1)
            if a:
                prefix_ = smart_unicode(a[0])
                if not prefix_.startswith(prefix):
                    self.client.zrem(self.PREFIX_KEY % self.name, prefix)
                    self.client.zrem(self.SUFFIX_KEY % self.name, _suffix)
                else:
                    break
            else:
                self.client.zrem(self.PREFIX_KEY % self.name, prefix)
                self.client.zrem(self.SUFFIX_KEY % self.name, _suffix)

# Demo script; it needs a reachable Redis server (RedisACKeywords connects and
# pings in its constructor). The trailing "# ..." notes on the print lines are
# the author's expected results and are not checked by the code.
if __name__ == '__main__':
    # Scenario 1: three overlapping keywords, one lookup, then clean up.
    acs = RedisACKeywords(name='test3')
    ks = ['aabbc', 'abbc', 'bbb']
    for k in ks:
        acs.add(k)

    print 'find result in aaabbbccc is :', acs.find('aaabbbccc')
    acs.flush()



    # Scenario 2: add keywords one at a time, dumping the Redis state after each.
    keywords = RedisACKeywords(name='test2')

    ks = ['her', 'he', 'his', 'here', 'there']
    for k in ks:
        keywords.add(k)
        keywords.debug_print()
        print '>>>>>>>>>>>>'

    text = 'here'
    print 'text: %s' % text
    print 'keywords: %s added. ' % ' '.join(ks), keywords.find(text) # her, he

    # Prefix completion over the registered keywords.
    input = 'he'
    print 'suggest %s: ' % input, keywords.suggest(input) # her, he

    text = 'ushers'
    print 'text: %s' % text
    print 'keywords: %s added. ' % ' '.join(ks), keywords.find(text) # her, he

    # Adding keywords later and searching the same text again.
    ks2 = ['she', 'hers']
    for k in ks2:
        keywords.add(k)
    print 'keywords: %s added. ' % ' '.join(ks2), keywords.find(text) # her, he, she, hers

    keywords.add('h')
    print 'h added. ', keywords.find(text) # her, he, she, hers, h

    keywords.remove('h')
    print 'h removed. ', keywords.find(text) # her, he, she, hers

    # After flush() the same lookup is repeated on the emptied store.
    keywords.flush()
    print 'flushed. ', keywords.find(text) # []


大家觉得文章对你有些作用! 如果想赏钱,可以用微信扫描下面的二维码,感谢!
另外再次标注博客原地址  xiaorui.cc

1 Response

  1. orangleliu 2014年10月31日 / 下午5:32

    KMP 这些都开始研究啦,不错

发表评论

邮箱地址不会被公开。 必填项已用*标注