Logistic Regression: Loss Function Derivation and Parameter Updates

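Logistic regression is a generalized linear model in which the function connecting the linear predictor to the output is the sigmoid function.

A generalized linear model has the form $y = g(\boldsymbol{w}^T\boldsymbol{x} + b)$. The function $g(\cdot)$ ties the linear model's prediction to the true label and is called the "link function" here, so that $g^{-1}(y)$ and $\boldsymbol{w}^T\boldsymbol{x} + b$ are linearly related. (In standard GLM terminology the name "link function" refers to the inverse mapping $g^{-1}$, which for logistic regression is the logit.)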

Model
$$
\begin{aligned}
z &= \boldsymbol{w}^T\boldsymbol{x} + b \\
y &= \frac{1}{1+e^{-z}}
\end{aligned}
$$
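
As a minimal sketch of the model above, the forward pass can be written in plain Python. The function names `sigmoid` and `predict_proba`, the branch on the sign of `z` (a common trick to avoid overflow in `exp`), and the sample values are illustrative assumptions, not part of the original derivation.

```python
import math

def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z)), evaluated without overflow."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # z < 0 here, so exp(z) cannot overflow
    return ez / (1.0 + ez)

def predict_proba(w, b, x):
    """Return y = sigmoid(w^T x + b) for one sample x (lists of floats)."""
    z = sum(wj * xj for wj, xj in zip(w, x)) + b
    return sigmoid(z)

# Hypothetical parameters and sample, chosen only for illustration.
w, b = [0.5, -0.25], 0.1
x = [2.0, 4.0]
y = predict_proba(w, b, x)  # here z = 1.0 - 1.0 + 0.1 = 0.1
```

The output `y` is interpreted as the predicted probability of the positive class.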

Deriving the loss function
We derive the loss function by maximum likelihood estimation.
First, write the predicted probabilities of the positive and negative classes, where $y^*$ denotes the true label and $y$ the model output:
$$
\begin{aligned}
p(y^*=1 \mid \boldsymbol{x}; \boldsymbol{w}, b) &= y \\
p(y^*=0 \mid \boldsymbol{x}; \boldsymbol{w}, b) &= 1-y
\end{aligned}
$$
The two expressions combine into one:
$$
p(y^* \mid \boldsymbol{x}; \boldsymbol{w}, b) = y^{y^*}(1-y)^{1-y^*}
$$
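
The unified expression can be checked by plugging in both label values. This is a small sketch; the function name `bernoulli_likelihood` and the value `0.8` are hypothetical.

```python
def bernoulli_likelihood(y, y_true):
    """p(y* | x; w, b) = y^{y*} * (1 - y)^{1 - y*} for a label y* in {0, 1}."""
    return (y ** y_true) * ((1.0 - y) ** (1 - y_true))

y = 0.8  # hypothetical predicted probability of the positive class
p_pos = bernoulli_likelihood(y, 1)  # the exponent selects the factor y
p_neg = bernoulli_likelihood(y, 0)  # the exponent selects the factor 1 - y
```

For $y^*=1$ the second factor has exponent 0 and drops out; for $y^*=0$ the first factor does.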
Assume the samples are independent and identically distributed. We want the predicted probability for each sample to be as close as possible to its true class, which gives the likelihood function:
$$
L(\boldsymbol{w}, b) = \prod_{i=1}^{m} p(y_i^* \mid \boldsymbol{x}_i; \boldsymbol{w}, b) = \prod_{i=1}^{m} y_i^{y_i^*}(1-y_i)^{1-y_i^*}
$$
Taking the logarithm:
$$
l(\boldsymbol{w}, b) = \sum_{i=1}^{m} \left[ y_i^* \log(y_i) + (1-y_i^*) \log(1-y_i) \right]
$$
A loss function is conventionally minimized, so we negate the log-likelihood and average over the $m$ samples:
$$
J(\boldsymbol{w}, b) = -\frac{1}{m}\sum_{i=1}^{m} \left[ y_i^* \log(y_i) + (1-y_i^*) \log(1-y_i) \right]
$$
This is exactly the cross-entropy loss.
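
The relationship between the likelihood and the loss can be illustrated numerically. The sketch below assumes every prediction lies strictly between 0 and 1 (otherwise `log` fails); the function names and the toy predictions are hypothetical.

```python
import math

def cross_entropy(y_pred, y_true):
    """J = -(1/m) * sum_i [ y*_i log(y_i) + (1 - y*_i) log(1 - y_i) ]."""
    total = 0.0
    for y, t in zip(y_pred, y_true):
        total += t * math.log(y) + (1 - t) * math.log(1.0 - y)
    return -total / len(y_pred)

def likelihood(y_pred, y_true):
    """L = prod_i y_i^{y*_i} * (1 - y_i)^{1 - y*_i}."""
    result = 1.0
    for y, t in zip(y_pred, y_true):
        result *= (y ** t) * ((1.0 - y) ** (1 - t))
    return result

# Hypothetical model outputs and true labels.
y_pred = [0.9, 0.2, 0.7]
y_true = [1, 0, 1]

J = cross_entropy(y_pred, y_true)
# The same value, obtained as the negative log-likelihood divided by m.
J_from_likelihood = -math.log(likelihood(y_pred, y_true)) / len(y_pred)
```

Both routes give the same number up to rounding, which is why maximizing the likelihood and minimizing $J$ are equivalent.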

Gradient derivation

Computation graph: $\boldsymbol{w}, b \rightarrow z \rightarrow y \rightarrow J$

By the chain rule:

$$
\frac{\partial J}{\partial w_j} = \sum_{i=1}^{m} \frac{\partial J}{\partial y_i} \frac{\partial y_i}{\partial z_i} \frac{\partial z_i}{\partial w_j}
$$

Compute each factor in turn. Starting from the loss

$$
J(\boldsymbol{w}, b) = -\frac{1}{m}\sum_{i=1}^{m} \left[ y_i^* \log(y_i) + (1-y_i^*) \log(1-y_i) \right]
$$

the derivative with respect to $y_i$ is

$$
\begin{aligned}
\frac{\partial J}{\partial y_i} &= -\frac{1}{m}\left(\frac{y_i^*}{y_i} - \frac{1-y_i^*}{1-y_i}\right) \\
&= -\frac{1}{m}\left(\frac{y_i^* - y_i}{y_i(1-y_i)}\right)
\end{aligned}
$$

For the sigmoid, written equivalently as $y = \frac{1}{1+e^{-z}} = \frac{e^{z}}{1+e^{z}}$, the quotient rule gives

$$
\begin{aligned}
\frac{\partial y_i}{\partial z_i} &= \frac{e^{z_i}(1+e^{z_i}) - e^{z_i}e^{z_i}}{(1+e^{z_i})^2} \\
&= \frac{e^{z_i}}{(1+e^{z_i})^2} \\
&= y_i(1-y_i)
\end{aligned}
$$
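
The identity $\partial y / \partial z = y(1-y)$ can be sanity-checked against a finite-difference approximation. This is only a numerical sketch; the helper `numeric_grad`, the step size `eps`, and the test points are assumptions made for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_grad(z):
    """Analytic derivative dy/dz = y * (1 - y)."""
    y = sigmoid(z)
    return y * (1.0 - y)

def numeric_grad(f, z, eps=1e-6):
    """Central finite difference, used only as an independent check."""
    return (f(z + eps) - f(z - eps)) / (2.0 * eps)

# Largest gap between the analytic and numeric derivative at a few points.
max_gap = max(abs(sigmoid_grad(z) - numeric_grad(sigmoid, z))
              for z in (-3.0, -0.5, 0.0, 1.0, 4.0))
```

At $z = 0$ the derivative takes its maximum value $0.5 \cdot 0.5 = 0.25$.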

For the linear part $z_i = \boldsymbol{w}^T\boldsymbol{x}_i + b$,

$$
\frac{\partial z_i}{\partial w_j} = x_{ij}
$$

where $x_{ij}$ is the $j$-th feature of sample $i$.
Substituting the three derivatives into the chain rule gives
$$
\begin{aligned}
\frac{\partial J}{\partial w_j} &= \sum_{i=1}^{m} -\frac{1}{m}\left(\frac{y_i^* - y_i}{y_i(1-y_i)}\right) \cdot y_i(1-y_i) \cdot x_{ij} \\
&= \frac{1}{m}\sum_{i=1}^{m}(y_i - y_i^*)\,x_{ij}
\end{aligned}
$$
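
One way to gain confidence in the final gradient expression is to compare it with finite differences of the loss. The sketch below is one possible check, not a reference implementation; the toy data, parameter values, and function names are all hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(w, b, x):
    """y_i = sigmoid(w^T x_i + b) for a single sample."""
    return sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b)

def loss(w, b, X, t):
    """Cross-entropy J(w, b) averaged over the m samples."""
    total = 0.0
    for x, ti in zip(X, t):
        y = forward(w, b, x)
        total += ti * math.log(y) + (1 - ti) * math.log(1.0 - y)
    return -total / len(X)

def grad_w(w, b, X, t):
    """Analytic gradient: dJ/dw_j = (1/m) * sum_i (y_i - y*_i) * x_ij."""
    m = len(X)
    g = [0.0] * len(w)
    for x, ti in zip(X, t):
        err = forward(w, b, x) - ti
        for j, xj in enumerate(x):
            g[j] += err * xj / m
    return g

# Hypothetical toy data and parameters, used only for the check.
X = [[1.0, 2.0], [2.0, -1.0], [-1.5, 0.5], [0.0, 1.0]]
t = [1, 0, 0, 1]
w, b = [0.3, -0.2], 0.1

analytic = grad_w(w, b, X, t)

# Central finite differences of J with respect to each w_j.
eps = 1e-6
numeric = []
for j in range(len(w)):
    w_plus, w_minus = list(w), list(w)
    w_plus[j] += eps
    w_minus[j] -= eps
    numeric.append((loss(w_plus, b, X, t) - loss(w_minus, b, X, t)) / (2.0 * eps))

max_gap = max(abs(a - n) for a, n in zip(analytic, numeric))
```

If the derivation is right, `max_gap` should be tiny compared with the gradient entries themselves.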

Parameter update rule

With learning rate $\alpha$, each weight is updated by gradient descent:

$$
\begin{aligned}
w_j &= w_j - \alpha \frac{\partial J}{\partial w_j} \\
&= w_j - \alpha \frac{1}{m}\sum_{i=1}^{m}(y_i - y_i^*)\,x_{ij}
\end{aligned}
$$
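
A minimal batch gradient descent loop built on this update might look as follows. The dataset, the learning rate `alpha = 0.1`, and the 500 iterations are arbitrary illustrative choices. The bias update is an assumption not spelled out in the text: it follows from the same chain rule with $\partial z_i / \partial b = 1$.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(w, b, x):
    return sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b)

def loss(w, b, X, t):
    total = 0.0
    for x, ti in zip(X, t):
        y = forward(w, b, x)
        total += ti * math.log(y) + (1 - ti) * math.log(1.0 - y)
    return -total / len(X)

def gradient_step(w, b, X, t, alpha):
    """One batch update: w_j <- w_j - alpha * (1/m) * sum_i (y_i - y*_i) * x_ij."""
    m = len(X)
    errors = [forward(w, b, x) - ti for x, ti in zip(X, t)]
    new_w = [wj - alpha * sum(e * x[j] for e, x in zip(errors, X)) / m
             for j, wj in enumerate(w)]
    # Bias update (assumption): dz_i/db = 1, so dJ/db = (1/m) * sum_i (y_i - y*_i).
    new_b = b - alpha * sum(errors) / m
    return new_w, new_b

# Hypothetical, linearly separable toy data.
X = [[0.5, 1.0], [1.0, 1.5], [1.5, 0.5],
     [-0.5, -1.0], [-1.0, -0.5], [-1.5, -1.0]]
t = [1, 1, 1, 0, 0, 0]

w, b = [0.0, 0.0], 0.0
alpha = 0.1
initial_loss = loss(w, b, X, t)  # equals log(2) at w = 0, b = 0
for _ in range(500):
    w, b = gradient_step(w, b, X, t, alpha)
final_loss = loss(w, b, X, t)
predictions = [1 if forward(w, b, x) > 0.5 else 0 for x in X]
```

Comparing `initial_loss` with `final_loss` shows whether the updates are reducing the cross-entropy on this toy set.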
