Naive Bayes
- Creating the vocabulary list
Add each document's set of new words to the vocabulary (a combined usage sketch for the helpers in this list follows below).
```python
def createVocabList(dataSet):
    vocabSet = set([])
    for document in dataSet:
        # Take the union of the two sets
        vocabSet = vocabSet | set(document)
    return list(vocabSet)
```
- Set-of-words model
Use the vocabulary to convert a document into a document vector; each element of the vector indicates whether the corresponding vocabulary word appears in the input document.
```python
def setOfWords2Vec(vocabList, inputSet):
    # Create a vector of the same length as the vocabulary, with every element set to 0
    returnVec = [0] * len(vocabList)
    # Iterate over every word in the document
    for word in inputSet:
        # If the word appears in the vocabulary
        if word in vocabList:
            # Set the corresponding value in the output document vector to 1
            returnVec[vocabList.index(word)] = 1
        else:
            print("the word: %s is not in my Vocabulary!" % word)
    return returnVec
```
- Bag-of-words model
Each element of the document vector records how many times the corresponding vocabulary word appears in the document.
```python
def bagOfWords2VecMN(vocabList, inputSet):
    returnVec = [0] * len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] += 1
    return returnVec
```
- Splitting text with regular expressions
Capture all words, discard tokens shorter than three characters, and convert everything to lowercase.
```python
import re

def textParse(bigString):
    # Split on any run of one or more non-word characters
    listOfTokens = re.split(r'\W+', bigString)
    return [tok.lower() for tok in listOfTokens if len(tok) > 2]
```
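A minimal usage sketch tying the four helpers together (the two sample sentences are my own, and I assume the functions above are defined in the same session):

```python
docs = [textParse('Dogs are great pets, truly great companions.'),
        textParse("My cat hates the neighbor's dog!")]

vocabList = createVocabList(docs)            # word order varies, since sets are unordered
print(setOfWords2Vec(vocabList, docs[0]))    # 0/1 flags: does each vocabulary word appear?
print(bagOfWords2VecMN(vocabList, docs[0]))  # counts: 'great' contributes 2, not 1
```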
P.S. The regular expression used in the book is `r'\W*'`, which raises `split() requires a non-empty pattern match` at run time. The official documentation says:
> Note: `split()` doesn't currently split a string on an empty pattern match. For example:
>
> ```python
> >>> re.split('x*', 'axbc')
> ['a', 'bc']
> ```
>
> Even though `'x*'` also matches 0 'x' before 'a', between 'b' and 'c', and after 'c', currently these matches are ignored. The correct behavior (i.e. splitting on empty matches too and returning `['', 'a', 'b', 'c', '']`) will be implemented in future versions of Python, but since this is a backward incompatible change, a FutureWarning will be raised in the meanwhile.
>
> Patterns that can only match empty strings currently never split the string. Since this doesn't match the expected behavior, a ValueError will be raised starting from Python 3.5:
>
> ```python
> >>> re.split("^$", "foo\n\nbar\n", flags=re.M)
> Traceback (most recent call last):
>   File "<stdin>", line 1, in <module>
>     ...
> ValueError: split() requires a non-empty pattern match.
> ```
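For reference, a quick check (my own snippet, not from the book) showing that the non-empty pattern `r'\W+'` used above sidesteps the problem:

```python
import re

# '\W+' always consumes at least one character, so it never produces
# an empty match and no FutureWarning/ValueError is triggered
print(re.split(r'\W+', 'Hello, world! Naive Bayes'))
# ['Hello', 'world', 'Naive', 'Bayes']
```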
Training and classification functions
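Both functions implement the standard naive Bayes decision rule; as a short refresher (my summary, matching the code below), for a document vector $w$ and class $c_i$:

$$
p(c_i \mid w) = \frac{p(w \mid c_i)\,p(c_i)}{p(w)}, \qquad
p(w \mid c_i) \approx \prod_j p(w_j \mid c_i)
$$

Since $p(w)$ is the same for every class, `classifyNB` only compares $\sum_j \log p(w_j \mid c_i) + \log p(c_i)$; working with sums of logs avoids the underflow that multiplying many small probabilities would cause.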
- Training function (an end-to-end usage sketch follows this list)
```python
from numpy import ones, log

def trainNB0(trainMatrix, trainCategory):
    numTrainDocs = len(trainMatrix)
    numWords = len(trainMatrix[0])
    # Probability that any given document is abusive
    pAbusive = sum(trainCategory) / float(numTrainDocs)
    # Initialize probabilities (counts start at 1 and denominators at 2
    # so that unseen words do not produce zero probabilities)
    p0Num = ones(numWords)
    p1Num = ones(numWords)
    p0Denom = 2.0
    p1Denom = 2.0
    # Iterate over all documents in the training set
    for i in range(numTrainDocs):
        if trainCategory[i] == 1:
            # Vector addition
            p1Num += trainMatrix[i]
            p1Denom += sum(trainMatrix[i])
        else:
            p0Num += trainMatrix[i]
            p0Denom += sum(trainMatrix[i])
    # Divide element-wise and take the natural log to avoid underflow
    p1Vect = log(p1Num / p1Denom)
    p0Vect = log(p0Num / p0Denom)
    return p0Vect, p1Vect, pAbusive
```
- Classification function
```python
from numpy import log

def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
    # The element-wise product picks out the log-probabilities of the words
    # present in the document; summing them is the log of the product
    p1 = sum(vec2Classify * p1Vec) + log(pClass1)
    p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
    if p1 > p0:
        return 1
    else:
        return 0
```
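A minimal end-to-end sketch (the four toy posts and labels below are made up for illustration; they are not the book's data set):

```python
from numpy import array

posts = [['my', 'dog', 'is', 'cute'],
         ['stupid', 'worthless', 'garbage'],
         ['my', 'dog', 'ate', 'my', 'steak'],
         ['stupid', 'dog', 'worthless', 'garbage']]
labels = [0, 1, 0, 1]  # 1 = abusive

vocabList = createVocabList(posts)
trainMat = [setOfWords2Vec(vocabList, post) for post in posts]
p0V, p1V, pAb = trainNB0(array(trainMat), array(labels))

print(classifyNB(array(setOfWords2Vec(vocabList, ['stupid', 'garbage'])), p0V, p1V, pAb))  # 1
print(classifyNB(array(setOfWords2Vec(vocabList, ['my', 'cute', 'dog'])), p0V, p1V, pAb))  # 0
```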
Spam email test function
```python
import random
from numpy import array

def spamTest():
    docList = []
    classList = []
    fullText = []
    # Load the text files under the spam and ham folders and parse them into word lists
    for i in range(1, 26):
        # spam
        wordList = textParse(open('email/spam/%d.txt' % i).read())
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(1)
        # ham
        wordList = textParse(open('email/ham/%d.txt' % i).read())
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(0)
    # Create the vocabulary
    vocabList = createVocabList(docList)
    # Create the training set (indices 0..49)
    trainingSet = list(range(50))
    # Create the test set
    testSet = []
    # Hold-out cross validation: randomly select a portion of the data as the
    # training set and use the remainder as the test set
    for i in range(10):
        # Randomly pick an email
        randIndex = int(random.uniform(0, len(trainingSet)))
        # Add it to the test set
        testSet.append(trainingSet[randIndex])
        # Remove it from the training set
        del trainingSet[randIndex]
    trainMat = []
    trainClasses = []
    # Iterate over all documents in the training set
    for docIndex in trainingSet:
        # Build a word vector from the vocabulary
        trainMat.append(setOfWords2Vec(vocabList, docList[docIndex]))
        trainClasses.append(classList[docIndex])
    p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
    errorCount = 0
    for docIndex in testSet:
        wordVector = setOfWords2Vec(vocabList, docList[docIndex])
        if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
            errorCount += 1
    print('the error rate is: ', float(errorCount) / len(testSet))
```
P.S. The officially provided test data contains illegal characters, which can make reading some of the files fail; a possible workaround is sketched below.
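One way to cope, sketched under the assumption that the undecodable bytes carry no useful words; `safeRead` is a hypothetical helper name, not from the book:

```python
def safeRead(path):
    # Read raw bytes, then decode while dropping anything that is not valid UTF-8
    with open(path, 'rb') as f:
        return f.read().decode('utf-8', errors='ignore')

# Usage inside spamTest:
# wordList = textParse(safeRead('email/spam/%d.txt' % i))
```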