AdaBoost算法讲解以及MATLAB实现
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
文章目录
李航统计学习adaboost算法例子(MATLAB)复现:https://blog.csdn.net/weixin_41146894/article/details/111316819
一、算法描述
AdaBoost,是英文"Adaptive Boosting"(自适应增强)的缩写,由Yoav Freund和Robert Schapire在1995年提出。它的自适应在于:前一个基本分类器分错的样本会得到加强,加权后的全体样本再次被用来训练下一个基本分类器。同时,在每一轮中加入一个新的弱分类器,直到达到某个预定的足够小的错误率或达到预先指定的最大迭代次数。
二、算法步骤是什么?
具体说来,整个Adaboost 迭代算法就3步:
1 初始化训练数据的权值分布。如果有N个样本,则每一个训练样本最开始时都被赋予相同的权值:1/N。
2 训练弱分类器。具体训练过程中,如果某个样本点已经被准确地分类,那么在构造下一个训练集中,它的权值就被降低;相反,如果某个样本点没有被准确地分类,那么它的权值就得到提高。然后,权值更新过的样本集被用于训练下一个分类器,整个训练过程如此迭代地进行下去。
3 将各个训练得到的弱分类器组合成强分类器。各个弱分类器的训练过程结束后,加大分类误差率小的弱分类器的权重,使其在最终的分类函数中起着较大的决定作用,而降低分类误差率大的弱分类器的权重,使其在最终的分类函数中起着较小的决定作用。换言之,误差率低的弱分类器在最终分类器中占的权重较大,否则较小。
三、数据集说明(可以自己增加数据集)
该数据集是参考网上设计的,数据量偏少,所以算法迭代次数只有几次。这里采用的数据集共有两个特征,X_1,X_2,对应的点组成一个具体的点,例如(5,2),(2,2) 其对应的类别为分别为1,1.
四、算法实现MATLAB
1.主函数
clc
clear
%数据集,自己设定,也可以输入,但是这里只适合二分类
X_2 = [5 2 1 6 8 5 9 7 8 2 1];
X_1 = [1 2 3 4 6 6 7 8 9 10 10];
%初始权重
w = [1 1 1 1 1 1 1 1 1 1 1];
%对应的输出
H = [];
%对应的目标函数
Y = [1 1 1 -1 -1 -1 1 1 1 -1 1];
Y_1 = Y;
% fprintf('目标分类:\n');
% fprintf('%4d',Y)
% fprintf('\n');
%对应的X
X = [0 1 2 3 4 5 6 7 8 9];
error = [];
%保存错误的个数
figure(1);
plot(X_1,X_2,'r*');
xlabel('X1');
ylabel('X2');
respose_1 = [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0];
respose_2 = [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0];
%对应的切分点
cut = [0.5 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5];
X_5 = [];
X_6 = [];
h_1 = [];
%用来保存误差
e = [];
%保存系数
a = [];
%设置最大的迭代步数
m = 20;
%开始迭代
for i=1:m
[h1,h2,e1,e2] = small_adaboost_con( X_1, X_2, cut, w, respose_1, respose_2, Y);
[Y_1,a(i),respose_1(i),respose_2(i),w] = cofficient_function(e1, e2, h1, h2, w, Y, respose_1, respose_2,cut);
H(i,:) =a(i)*Y_1;
sign(sum(H));
%记录错误的个数
error(i)=sum(abs(sign(sum(H))- Y))/2;
fprintf('\n');
fprintf('迭代第 %d 次, 错误率为: %f',i, error(i)/11);
fprintf('\n');
%判断是否还存在错误划分点
if sum(abs(sign(sum(H))- Y))==0
break;
end
end
fprintf('\n');
fprintf('迭代结束\n');
fprintf('所有弱分类函数对应的系数:\n');
fprintf('%16f',a);
fprintf('\n')
figure(2);
plot(error,'r*--');
xlabel('iteration times');
ylabel('error numbers');
legend('error numbers');
2.找到误差率最小的弱函数
function [ Y1,a1,respose1,respose2,w] = cofficient_function(e1, e2, h1, h2, w, Y, respose_1,respose_2,cut)
%{
函数功能:找到误差率最小的弱函数
输入: h1:对应X_1的弱分类函数
h2:对应X_2的弱分类函数
e1:对应X_1弱函数的误差率
e2:对应X_2弱函数的误差率
cut:划分点,这里是二分类
w:当前的权值
repose_1:记录X_1的划分点
repose_2:记录X_2的划分点
Y:期望输出
输出: Y1:误差率最小的弱函数
a1: 对应其系数
respose1:当前X_1使用的划分点,为了防止重复使用同一划分点
rrespose2:当前X_2使用的划分点
%}
%找出e1最小误差
minE_1 = min(e1);
%算弱分类器前的系数
minI_1 = find(e1 == minE_1);
%判断是否有两个或则更多的最小误差
if (size(minI_1,2) > 1)
minI_1 = minI_1(1);
end
minE_1;
minI_1;
%找出最小e2误差
minE_2 = min(e2);
%算弱分类器前的系数
minI_2 = find(e2 == minE_2);
%判断是否有两个或则更多的最小误差
if (size(minI_2,2) > 1)
minI_2 = minI_2(1);
end
minE_2;
minI_2;
if minE_2< minE_1
minI = minE_2;
minI_Y = minI_2;
minI_X = 0;
else
minI = minE_1;
minI_X = minI_1;
minI_Y = 0;
end
if minI_X == 0
%得到系数
% fprintf('Y')
a = log((1 - minE_2)/minE_2)/2;
%更新权值
Y_1 = h2(minI_Y,:);
Z = sum(w.*exp(-a * h2(minI_Y,:).*Y));
w = w/Z .* exp(-a * h2(minI_Y,:).*Y);
respose_2 = minI_Y;
%保存已经使用过的判断点
hold on;
plot([0,10],[cut(minI_Y),cut(minI_Y)]);
else
respose_2 = 0;
end
if minI_Y == 0
% fprintf('X')
%得到系数
a = log((1 - minE_1)/minE_1)/2;
%更新权值
Y_1 = h1(minI_X,:);
Z = sum(w.*exp(-a * h1(minI_X,:).*Y));
w = w/Z .* exp(-a * h1(minI_X,:).*Y);
%保存已经使用过的判断点
respose_1 = minI_X;
hold on;
plot([cut(minI_X),cut(minI_X)],[0,10]);
else
respose_1=0;
end
Y1=Y_1;
a1 = a;
respose1 = respose_1;
respose2 = respose_2;
w1 = w;
3.分别计算弱分类函数,并计算其误差率
function [h1, h2,e1,e2] = small_adaboost_con( label_1, label_2, cut, w, repose_1, repose_2, Y)
%{
%函数功能:分别计算弱分类函数,并计算其误差率
%输入:label_1:X_1的特征值
label_2;X_2的特征值
cut:划分点,这里是二分类
w:当前的权值
repose_1:记录X_1的划分点
repose_2:记录X_2的划分点
Y:期望输出
%输出:
h1:对应X_1的弱分类函数
h2:对应X_2的弱分类函数
e1:对应X_1弱函数的误差率
e2:对应X_2弱函数的误差率
%}
%初始化权值大小,归一化到(0,1)
sumW = sum(w);
w = w ./ sumW;
for j=1:length(cut)
for k=1:length(label_1)
if label_1(k) < cut(j)
h_1(j,k) = 1;
else
h_1(j,k) = -1;
end
if label_2(k) < cut(j)
h_2(j,k) = 1;
else
h_2(j,k) = -1;
end
end
end
% %计算特征值1的误差率
for j = 1:size(h_1,1)
if isempty(find( repose_1==j))==0
e_1(j) = 999;
else
%误差率大于0.5,
e_1(j) = sum(w.*((Y - h_1(j,:)).*Y) / 2);
if e_1(j) > 0.5
e_1(j) = 1 - e_1(j);
for k=1:length(label_1)
if label_1(k) < cut(j)
h_1(j,k) = -1;
else
h_1(j,k) = 1;
end
end
end
end
end
%计算特征值2的误差率
for j = 1:size(h_2,1)
if isempty(find( repose_2==j))==0
e_2(j) = 999;
else
%误差率大于0.5,
e_2(j) = sum(w.*((Y - h_2(j,:)).*Y) / 2);
if e_2(j) > 0.5
e_2(j) = 1 - e_2(j);
for k=1:length(label_2)
if label_2(k) < cut(j)
h_2(j,k) = -1;
else
h_2(j,k) = 1;
end
end
end
end
end
h1 = h_1;
h2 = h_2;
e1 = e_1;
e2 = e_2;
该处使用的url网络请求的数据。
六、结果演示