【3.0】决策树
前言
如果你以前没有接触过决策树,也不需要担心,它的概念非常简单。即使不知道它也可以通过简单的图形了解其中的工作原理,下图的流程图就是一个决策树,长方形代表判断模块(decision block),椭圆形代表终止模块(terminating block),表示这已经得出结论,可以终止运行。从判断模块引出的左右箭头称为分支(branch),它可以到达另一个判断模块或者终止模块。该流程图构造了一个假想的邮件分类系统,它首先检测发生邮件域名地址。如果地址为 myEmployer.com ,则将其放在分类 “无聊时需要阅读的邮件”,其他同理分类。
K-近邻算法已经可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内在含义,决策树的主要优势就是在于数据形式非常容易理解。决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,这些机器根据数据集创建规则的过程,就是机器学习的过程。
前排提醒:在接下来的代码示例中,有看不懂的函数,可以尝试在下面的 函数相关说明 处查看
决策树的构造
决策树的优缺点
- 优点:计算复杂度不高,输出结果容易理解,对中间值的缺失并不敏感,可以处理不相关特征数据。
- 缺点:可能会产生过度匹配的问题
- 适用数据类型:数值型和标称型。
首先,我们讨论数学上如何使用信息论划分数据集,然后编写代码将理论应用到具体的数据集上,最后编写代码构建决策树。
在构造决策树时,我们需要解决的第一个问题:当前数据集上哪个特征在划分数据分类时起决定性作用。 为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。完成测试之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一类型,则不需要再分割了。如果某分支下的数据不属于同一类型,则需要重复划分数据子集,直到所有相同类型的数据被划分为各自的子集中。
创建分支的伪代码如下所示:
1 | 检测数据集中的每个子项是否属于同一分类 |
这个伪代码函数是一个递归函数,后续我们会使用python
代码来实现这段伪代码。一些决策树采用二分法划分数据,本文并不采用这种方法。本文将使用 ID3 算法划分数据集。每次划分数据集时,我们只选取一个特征属性,如果训练集中存在 20 个特征,第一次我们选择哪个特征作为划分的参考属性呢?。
一些常见的决策树算法:
- ID3(Iterative Dichotomiser 3):ID3 是最早的决策树算法之一,它使用信息增益来选择最优的特征进行分裂。然而,ID3 倾向于选择具有更多取值的特征,因此在实践中往往使用其他算法。
- C4.5:C4.5 是 ID3 的改进版本,它使用信息增益比来选择最优的特征。相对于 ID3,C4.5 能够处理连续特征和缺失数据,并且可以生成具有更好泛化能力的决策树。
- CART(Classification and Regression Trees):CART 是一种常用的决策树算法,可以用于分类和回归问题。CART 使用基尼系数(Gini Index)来选择最优的特征进行分裂,它生成的决策树是二叉树结构。
- CHAID(Chi-squared Automatic Interaction Detection):CHAID 是一种基于卡方检验的决策树算法,适用于分类问题。它可以处理离散和连续特征,并且能够检测特征之间的交互作用。
- Random Forest(随机森林):随机森林是一种集成学习方法,基于多个决策树进行预测。每个决策树都是通过随机选择样本和特征进行训练的,最后的预测结果由多个决策树的投票或平均值得出。
- Gradient Boosting Trees(梯度提升树):梯度提升树也是一种集成学习方法,通过迭代地训练决策树来提高预测性能。每个决策树都是在前一棵树的残差基础上进行训练的,最终的预测结果是多个决策树的加权和。
信息增益
划分数据集的大原则是:将无序的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。
在划分数据集之前之后信息发生的变化称为:信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
在评测哪种数据划分方式是最好的数据划分之前,我们必须学习如何计算信息增益。集合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德·艾尔伍德·香农。
如果看不明白什么是信息增益和熵,也不需要着急——它们自诞生的那一天起,就注定令人费解。
熵的定义为信息的期望值,在明晰这个概念之前,我们必须知道信息的定义。如果待分类的事务可能划分在多个分类之中,则符合 $x_i$ 的信息定义为:$\large l(x_i) = - log_2{p(x_i)}$ ,其中 $\large p(x_i)$ 是选择该分类的概率。
为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
$\large H = - \sum^{n}_{i=1}p(x_i) \log_2{p(x_i)}$,其中 $n$ 是分类的数目。
关于这两个公式的理解可以参考如何理解信息熵
下面我们将学习如何使用python
计算信息熵,创建名称为trees.py
的文件,如下代码为计算给定数据集的熵。
1 | from math import log |
现在我们使用如下数据来测试一下我们的方法:
1 | dataSet = [[1,0,'y'],[0,0,'n'],[0,1,'n'],[1,0,'y'],[0,0,'n']] |
熵越高,则混合的数据也越多,我们可以在测试数据集中添加更多的分类,观察熵的变化,现在我们增加第三个名为z
的分类,测试熵的变化:
1 | dataSet = [[1,0,'y'],[0,0,'n'],[0,1,'n'],[1,0,'y'],[0,0,'n'],[0,0,'z']] |
得到熵之后,我们就可以按照获取最大信息增益的方法划分数据集,下个部分我们将具体学习如何划分数据集以及如何度量信息增益。
另一个度量集合无序程度的方法是基尼不纯度,简单来说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。本文不采用基尼不纯度方法,这里不做更多说明。
划分数据集
上个部分我们学习了如何度量数据集的无序程度,分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断是否正确地划分了数据集。我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。
现在我们使用如下代码按照特征来划分数据集,代码示例:
1 | #划分数据集,传入数据集,特征在数据集的位置,要划分的特征 |
现在我们来测试上述代码,代码示例:
1 | dataSet = [[1,0,'y'],[0,0,'n'],[0,1,'n'],[1,0,'y'],[0,0,'n'],[0,0,'z']] |
接下来我们将会遍历整个数据集,循环计算香农熵和splitDataSet()
函数(划分数据集),找到最好的特征划分方式。熵计算会告诉我们如何划分数据集是最好的数据组织方式。
1 | # 香农熵划分最佳数据集 |
如果你实在觉得绕看不懂,可以单步调试,或者在关键的地方让它输出看看结果,多次尝试就明白了。
递归构建决策树
目前我们已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则来处理数据集。
递归的结束条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。例如下图:
第一个结束条件可以使得算法可以终止,我们甚至可以设置算法可以划分的最大分组数目。后续还会说明其他决策树算法,例如 C4.5 和 CART,这些算法在运行时并不总是在每次划分分组时都会消耗特征。由于特征数目并不是在每次划分数据时减少,因此这些算法在实际使用时候可能会引起一些问题。目前我们并不需要考虑这个问题,只需要在算法开始运行计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定改叶子节点的分类。
现在我们打开tree.py
文件,在文件头部添加import operator
,然后在文本中添加如下代码:
1 | # 标签投票分类(如果数据集已经处理了所有属性,但是类标签依然不是唯一的,我们需要决定如何定义该叶子节点) |
现在我们来在文件中添加最后的递归相关的代码:
1 | # 创建决策树,传入数据集和标签列表 |
内容很多,很抽象是吧🤣,我也觉得很抽象,理解上面的整个代码运行过程,我们来举一个实例来理解这段代码,现在我们有如下的数据,我们需要对它们进行分类。
是否有脚 | 是否有鳞片 | 是否有鳃 | 是否有尾巴 | 【特征值】 |
---|---|---|---|---|
1 | 1 | 0 | 0 | 非鱼类 |
1 | 1 | 1 | 0 | 鱼类 |
0 | 1 | 0 | 1 | 鱼类 |
1 | 0 | 0 | 0 | 非鱼类 |
0 | 1 | 0 | 0 | 鱼类 |
我们将上述数据转换成运行的python
代码如下所示:
1 | dataSet = data = [ |
很明显的看出,上面的递归代码中的labels
就是表格的表头,它是用来给每一列数据进行标注的,或者说是用来解释数据的,对于计算机来说这一列并没有参考性,但是对于我们来说是有参考意义的。
现在我们运行这段代码,运行到这里:
1 | # 获取所有数据集最后一列的数据(标签) |
我们得到classList = ['非鱼类', '鱼类', '鱼类', '非鱼类', '鱼类']
,也就说明它提取了我们的所有数据的特征。
现在运行到下面的代码部分:
1 | # 如果传入的数据集都是一个类别,就直接返回节点 |
这两个部分对应的处理就是我们前面说的递归的结束情况,代码第一个if
部分判断,如果给定的数据集类别中,第一个类别的数量等于该数据集所有类别的数量,就说明它们都是一个类别的,已经不需要分类了(递归的结束条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类)。
第二个if
部分判断的是,如果我们的整个数据集只有一列了,那就说明只剩下了最右侧的特征值列,说明已经把属性都分类完了,这个时候也不再需要分类了(如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定改叶子节点的分类)。
现在来继续往下执行代码:
1 | # 划分最佳数据集 |
现在我们将我们的数据集进行第一次划分,形象点来说就是决策树第一次分叉,执行chooseBestFeatureToSplit()
函数后,我们的输出结果:bestFeat = 0
,它告诉了我们这个数据集的第一次划分最好的属性索引是 0
,对应索引的是是否有脚
。
具体划分原理,参考前面的信息增益部分。
现在我们知道了第一次应该按什么属性来划分,代码继续运行:
1 | # 得到划分最佳数据集的标签 |
通过这个,就得到了前面我说的最佳划分属性的标签,就是是否有脚
。我们在知道第一次划分的属性后,接下来构建决策树的雏形:
1 | # 创建一个字典,以最佳特征的标签为键,值为空字典,用于构建决策树 |
现在我们创建了一个变量myTree
来存储决策树,其中它的类型是字典类型,存储了一个key = bestFeatLabel
也就是key = '是否有脚'
的key
,它对应的value
是一个空的字典,也就是代码中的{}
,这行代码的其最终的结果:myTree = {'是否有脚': {}}
。
现在代码继续执行到如下位置:
1 | # 删除已选择的最佳特征的标签,以便在递归调用时传递给下一层 |
它删除了我们标签中的第一次分叉属性,也就是由之前的['是否有脚', '是否有鳞片', '是否有鳃', '是否有尾巴']
变成了['是否有鳞片', '是否有鳃', '是否有尾巴']
。
接下来代码继续执行到如下位置:
1 | # 获取数据集中最佳特征的所有取值 |
这句代码右侧是列表推导式,这句代码运行结果是:[1, 1, 0, 1, 0]
,它提取所有第一个分叉最佳属性的所有值,因为我们接下来要根据值来继续划分数据集了。
具体列表推导式是什么参考下面的相关函数说明
代码继续执行:
1 | # 获取最佳特征的唯一取值集合 |
这句代码执行结果就是去重,它的执行结果是:{0, 1}
,这样我们就得到了当前最佳属性的唯一取值集合,接下来就是“分叉”,第一个“叉”是按 0 来分的,第二个“叉”是按 1 来分的。
代码继续执行:
1 | # 递归遍历 |
现在我们遍历第一个“叉”,即value = 0
,我们先是完全拷贝了一份labels
给subLabels
,接下来,我们对于
第一个分叉,先做了一个划分,即splitDataSet(dataSet, bestFeat, value)
,它的运行结果是返回了:[[1, 0, 1, '鱼类'], [1, 0, 0, '鱼类']]
,也就是第一列所有value = 0
的值划分的一组(去除了第一列的值),然后形成的这个新的分组就是[[1, 0, 1, '鱼类'], [1, 0, 0, '鱼类']]
,转换成表格如下所示:
是否有鳞片 | 是否有鳃 | 是否有尾巴 | 【特征值】 |
---|---|---|---|
1 | 0 | 1 | 鱼类 |
1 | 0 | 0 | 鱼类 |
然后这组数据再次执行createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
进行分类构建,但是它很明显特征值都是一个类型的,即鱼类,所以它在执行到如下代码就返回了:
1 | # 如果传入的数据集都是一个类别,就直接返回节点 |
然后接下来回到递归遍历的地方,此时value = 1
,也就是右分叉再次执行这个循环,知道满足前面说的两个结束条件,然后才会结束,最后返回分类好的决策树。
**最终的运行结果是:{'是否有脚': {0: '鱼类', 1: {'是否有鳃': {0: '非鱼类', 1: '鱼类'}}}}
**。这个结果很难直观的来理解是吧,现在将其可视化,就是如下图所示:
好了,现在你应该已经了解了如何构造决策树了,对于晦涩难懂的输出,图更加帮助我们理解分类器的内在逻辑,接下来我们来绘制决策树,来可视化我们的决策树。
使用Graphviz
绘制树形图
需要说明的是 Python 本身并不具备绘制图形/图表的能力,我们需要通过拓展包来实现相关功能,在 Python 中有一些常用的包提供绘图相关操作:
- Matplotlib:Matplotlib是一个功能强大的绘图库,可以用于绘制各种类型的图表,包括树形图。
- NetworkX:NetworkX是一个专门用于创建、操作和研究复杂网络的Python库。它提供了一些功能强大的函数和算法,用于绘制树形图、图形布局和节点样式设置。
- Graphviz:Graphviz是一个开源的图形可视化工具包,可以用于绘制各种类型的图形,包括树形图。它使用DOT语言描述图形结构,并提供了Python接口供调用。
- anytree:anytree是一个轻量级的Python库,用于处理和操作树形数据结构。它提供了创建、遍历和操作树形结构的功能,并支持将树形结构可视化为文本、图形或其他格式。anytree提供了一些可选的渲染器,可以将树形结构绘制为图形。
从简单程度来说,使用Graphviz
包是比较简单的,所以我采用该包进行树形图绘制演示。
如何下载安装该包,此处不再做演示,绘制图形代码如下所示:
1 | import graphviz |
其渲染结果如下所示:
序列化决策树
构造决策树是很耗时的任务,如果面对的数据集很大,将会耗费更多的计算时间。然后如果我们使用创建好的决策树解决分类问题,将会大大节约时间。因此为了节省时间,最好是能够在每次执行分类时调用已经构造好的决策树。
为了解决这个问题,需要使用 Python 模块 pickle
序列化对象,代码如下所示。序列化对象可以在磁盘上存储,并在我们需要的时候读取出来。
1 | # 序列化决策树 |
这样我们在构建决策树的时候就可以序列化存储起来,然后需要的时候调用出来:
1 | dataSet = data = [ |
代码的输出结果:{'是否有脚': {0: '鱼类', 1: {'是否有鳃': {0: '非鱼类', 1: '鱼类'}}}}
。
通过上面的代码,我们可以将分类器存储在磁盘上,不必每次都需要学习一下,这也是决策树的优点之一,而相对于上一篇说明的KNN
(k-近邻算法)就无法持久化分类器。
使用决策树预测隐形眼镜类型
使用小数据集,我们就可以利用决策树学到很多知识:眼科医生是如何判断患者需要佩戴的镜片类型?一旦理解了决策树的工作原理,我们甚至也可以帮助人们判断需要佩戴的镜片类型。
关于隐形眼镜的数据集在文本的最后,相关数据部分,将数据记得保存在一个
txt
中。
现在我们在 Python 中调用如下代码:
1 | # 读取数据集 |
输出结果:{'tearRate': {'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'young': 'hard', 'presbyopic': 'no lenses', 'pre': 'no lenses'}}, 'myope': 'hard'}}, 'no': {'age': {'young': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'pre': 'soft'}}}}, 'reduced': 'no lenses'}}
。可视化后如下所示:
从上图我们也可以发现,医生最多需要四个问题就能确定患者需要佩戴哪种类型的隐形眼镜。
上图的决策树也非常好的匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配。为了减少过度匹配,我们可以裁剪决策树,去掉一些不必要的叶子节点。如果叶子节点只能增加少量信息,则可以删除该节点,将它并入其他叶子节点中。我们将会在后续讨论这个问题。
相关函数说明
extend()
extend()
是 Python 列表对象的一个方法,用于将一个可迭代对象中的元素逐个添加到列表中。它会修改原始列表,将可迭代对象中的元素追加到列表的末尾。代码示例:
1 | my_list = [1, 2, 3] |
你会发现它和
append()
函数很像,具体不同看下面
append()
append()
是 Python 列表对象的一个方法,用于将一个元素添加到列表的末尾。它会修改原始列表,将元素追加到列表的最后一个位置。代码示例:
1 | a = [1,2,3] |
列表推导式
在划分数据集中,featList = [example[i] for example in dataSet]
这就是一个列表推导式,它用于提取数据集中每个样本的第 i 个特征的取值。代码示例:
1 | dataSet = [ |
set()
set()
是一个Python内置函数,用于创建一个无序、不重复元素的集合。集合是一种可变的数据类型,它可以存储各种不同的元素,但不允许有重复的元素。代码示例:
1 | numbers = [1, 2, 3, 3, 4, 5, 5] |
相关数据
隐形眼镜数据集
1 | young myope no reduced no lenses |
End
本文使用的算法是 ID3 ,它是一个好的算法但是并不完美。ID3 算法无法直接处理数值型数据,尽管我们可以量化的方法将数值型数据转换为标称型数值,但是如果存在太多特征划分,ID3 算法仍然面临其他问题。
后续我们将会学习另一个构造决策树的算法 CART,它使用基尼系数(Gini Index)来选择最优的特征进行分裂。