References and key points:
- https://cloud.tencent.com/developer/article/1387666 reference implementation
- https://github.com/keras-team/keras-contrib the Keras add-on library that provides the CRF layer
- https://docs.floydhub.com/guides/environments/ mind the TensorFlow/Keras version pairing (version-check sketch after this list)
- https://github.com/jiesutd/LatticeLSTM data download
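Because keras-contrib's CRF layer only works with certain TensorFlow/Keras pairings, it is worth printing the installed versions before anything else. A minimal sanity-check sketch (the combination named in the comment is the commonly working one, not a guarantee):

```python
import tensorflow as tf
import keras

# keras-contrib targets standalone Keras (around 2.2.x) on a TensorFlow 1.x
# backend; TF 2.x / tf.keras mixes frequently break the CRF layer.
print('tensorflow:', tf.__version__)
print('keras:', keras.__version__)
```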
Data sample
```
中 B-LOC
国 E-LOC
的 O
天 B-PER
安 I-PER
门 E-PER
我 O
跟 O
他 O
谈 O
笑 O
风 O
生 O
```
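Read line by line, the sample decomposes into (character, tag) pairs, which is exactly the nested structure the parsing code in the next section builds. A self-contained sketch using the first few characters above:

```python
sample = """中 B-LOC
国 E-LOC
的 O"""

# One sentence becomes a list of [char, tag] pairs.
pairs = [line.split() for line in sample.split('\n')]
print(pairs)  # [['中', 'B-LOC'], ['国', 'E-LOC'], ['的', 'O']]
```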
Code
Data processing
```python
import numpy
from collections import Counter
from keras.preprocessing.sequence import pad_sequences
import pickle
import platform
import sys

sys.path.append('/tf/keras/keras-contrib')


def _parse_data(fh):
    # On Windows the corpus may use '\r\n' line endings (so the blank-line
    # separator is '\r\n\r\n'); on Linux it is '\n' / '\n\n'.
    if platform.system() == 'Windows':
        split_text = '\r\n'
    else:
        split_text = '\n'
    raw_corpus = fh.read().decode('utf-8')
    # A sample is a block of "char tag" lines; blank lines separate samples.
    data = [[row.split() for row in sample.split(split_text)]
            for sample in raw_corpus.strip().split(split_text + split_text)]
    fh.close()
    return data


def _process_data(data, vocab, chunk_tags, maxlen=None, onehot=False):
    if maxlen is None:
        maxlen = max(len(s) for s in data)
    word2idx = {w: i for i, w in enumerate(vocab)}
    # Characters missing from the vocab map to <unk> (index 1).
    x = [[word2idx.get(w[0].lower(), 1) for w in s] for s in data]
    y_chunk = [[chunk_tags.index(w[1]) for w in s] for s in data]
    x = pad_sequences(x, maxlen)  # left padding with 0
    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)
    if onehot:
        y_chunk = numpy.eye(len(chunk_tags), dtype='float32')[y_chunk]
    else:
        y_chunk = numpy.expand_dims(y_chunk, 2)
    return x, y_chunk


def process_data(data, vocab, maxlen=100):
    # Used at prediction time: `data` is a raw string of characters.
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    x = [word2idx.get(w[0].lower(), 1) for w in data]
    length = len(x)
    x = pad_sequences([x], maxlen)  # left padding
    return x, length


def load_data():
    train = _parse_data(open('demo.train.char', 'rb'))
    test = _parse_data(open('demo.test.char', 'rb'))

    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    # Reserve index 0 for padding and index 1 for <unk>, then keep characters
    # that occur at least twice, so OOV/padding indices never collide with
    # real characters.
    vocab = ['<pad>', '<unk>'] + [w for w, f in iter(word_counts.items()) if f >= 2]
    # Sort for a stable tag order across runs.
    chunk_tags = sorted(set(line[1] for sample in train for line in sample))

    # save initial config data
    with open('config.pkl', 'wb') as outp:
        pickle.dump((vocab, chunk_tags), outp)

    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    return train, test, (vocab, chunk_tags)
```
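To verify the pipeline before training, one can load the corpus and inspect the shapes. A quick sketch (assumes demo.train.char and demo.test.char from the LatticeLSTM repo sit in the working directory):

```python
(train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()

print(train_x.shape)   # (num_train_sentences, maxlen)
print(train_y.shape)   # (num_train_sentences, maxlen, 1), sparse tag indices
print(len(vocab))      # vocabulary size, including <pad> and <unk>
print(chunk_tags)      # e.g. ['B-LOC', 'B-PER', ..., 'O']
```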
Building the model
```python
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
import pickle

EMBED_DIM = 200
BiRNN_UNITS = 200


def create_model(train=True):
    # load_data comes from the data-processing code above.
    if train:
        (train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()
    else:
        with open('config.pkl', 'rb') as inp:
            (vocab, chunk_tags) = pickle.load(inp)
    model = Sequential()
    model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # random embedding
    model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
    crf = CRF(len(chunk_tags), sparse_target=True)
    model.add(crf)
    model.summary()
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    if train:
        return model, (train_x, train_y), (test_x, test_y)
    else:
        return model, (vocab, chunk_tags)


if __name__ == "__main__":
    EPOCHS = 10
    model, (train_x, train_y), (test_x, test_y) = create_model()
    # train model (the 'model/' directory must exist before saving)
    model.fit(train_x, train_y, batch_size=16, epochs=EPOCHS,
              validation_data=(test_x, test_y))
    model.save('model/crf.h5')
```
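On newer keras-contrib checkouts, crf.loss_function and crf.accuracy are deprecated in favor of standalone symbols. If compile complains, this variant (same training behavior, assuming a recent keras-contrib) may help:

```python
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy

model.compile('adam', loss=crf_loss, metrics=[crf_viterbi_accuracy])
```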
Prediction
```python
import numpy as np
import pickle

with open('config.pkl', 'rb') as inp:
    (vocab, chunk_tags) = pickle.load(inp)

# Rebuild the network in inference mode (create_model is defined in the
# model-building code above).
model, _ = create_model(train=False)

predict_text = '中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下,连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚'
sequence, length = process_data(predict_text, vocab)

model.load_weights('model/crf.h5')
# With left padding, the last `length` positions hold the real characters.
raw = model.predict(sequence)[0][-length:]
result = [np.argmax(row) for row in raw]
result_tags = [chunk_tags[i] for i in result]

per, loc, org = '', '', ''
for s, t in zip(predict_text, result_tags):
    if t in ('B-PER', 'M-PER', 'S-PER', 'E-PER'):
        per += ' ' + s if (t == 'B-PER') else s
    if t in ('B-ORG', 'M-ORG', 'S-ORG', 'E-ORG'):
        org += ' ' + s if (t == 'B-ORG') else s
    if t in ('B-LOC', 'M-LOC', 'S-LOC', 'E-LOC'):
        loc += ' ' + s if (t == 'B-LOC') else s
print(['person:' + per, 'location:' + loc, 'organization:' + org])
```
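Note that reloading the full saved model (rather than just the weights, as above) requires registering the CRF layer and its loss/metric as custom objects. A sketch assuming the standalone keras-contrib symbols:

```python
from keras.models import load_model
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy

# Without custom_objects, load_model fails with "Unknown layer: CRF".
model = load_model('model/crf.h5',
                   custom_objects={'CRF': CRF,
                                   'crf_loss': crf_loss,
                                   'crf_viterbi_accuracy': crf_viterbi_accuracy})
```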
Troubleshooting
- `TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [bool, float32] that don't all match.`
Fix: remove `mask_zero=True` from the Embedding layer (adjusted model sketched below).
mask_zero: whether to treat index 0 as a special "padding" value that should be masked out. This is useful for recurrent layers that take variable-length input. If set to True, all downstream layers must support masking, otherwise an exception is raised. Also, when mask_zero is True, index 0 cannot be used as a regular vocabulary entry (input_dim should equal vocabulary size + 1).
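A sketch of the model definition with the workaround applied (len(vocab), EMBED_DIM, BiRNN_UNITS, and chunk_tags as defined in the model-building section); the trade-off is that padding positions are no longer masked and therefore contribute to the loss and accuracy:

```python
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF

model = Sequential()
# Workaround: omit mask_zero=True so no boolean mask tensor gets
# concatenated with float tensors downstream (the source of the
# [bool, float32] ConcatV2 error).
model.add(Embedding(len(vocab), EMBED_DIM))
model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
crf = CRF(len(chunk_tags), sparse_target=True)
model.add(crf)
```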
