【算法】反向传播算法

news/2024/9/29 18:18:29 标签: 算法, 深度学习

David Rumelhart 是人工智能领域的先驱之一,他与 James McClelland 等人在1986年通过其著作《Parallel Distributed Processing: Explorations in the Microstructure of Cognition》详细介绍了反向传播算法(Backpropagation),这一算法为多层神经网络的训练提供了有效的途径,是深度学习发展的重要里程碑之一。

反向传播算法的核心思想:

反向传播(Backpropagation)算法是基于梯度下降法的一种优化算法,用来训练多层感知器(MLP)等神经网络模型。它的主要思想是,通过逐层计算误差的梯度,并向网络的反方向传播这些误差,更新神经网络的权重,以最小化损失函数。

以下是反向传播算法的基本步骤及其对应的数学公式:

一、前向传播(Forward Propagation)

前向传播的目的是计算神经网络的输出。对于第 l 层的线性组合和激活值:

1. 线性组合:

在这里插入图片描述
这里,W(l) 是权重矩阵,a(l−1) 是第 l−1 层的激活值,b(l) 是偏置项。

2. 激活值:

然后通过激活函数 g,得到第 l 层的激活值:
在这里插入图片描述

二、 损失函数计算(Loss Function Calculation)

网络的输出和真实标签(目标值)之间的差异通过损失函数来度量。例如,对于回归问题常用均方误差(MSE),对于分类问题常用交叉熵损失(Cross Entropy)。假设损失函数为 L,我们的目标是最小化 L。
在这里插入图片描述
其中 a(L) 是网络的输出,y 是真实标签。

三、 反向传播(Backpropagation)
1. 输出层的误差

每一层的误差通常用符号 δ(l)表示,对于输出层(假设是第 L 层),误差是最直接的,因为我们可以根据损失函数和网络的预测值计算它。

输出层的误差计算公式为:
在这里插入图片描述

其中:

  • ∂L/∂a(L) 是损失函数 L 对输出值 a(L) 的导数。这个值取决于损失函数的形式。例如,对于均方误差(MSE)损失函数:
    在这里插入图片描述

    其导数为:
    在这里插入图片描述

    对于交叉熵损失(Cross Entropy),其导数形式不同,但基本过程相同。

  • ∂a(L)/∂z(L)​ 是激活函数 g(z(L)) 的导数:
    在这里插入图片描述

因此,输出层的误差可以写成:
在这里插入图片描述

2. 隐藏层的误差

对于隐藏层,我们仍然使用链式法则来计算损失函数对 z(l) 的导数。具体来说,假设我们已经知道第 l+1 层的误差 δ(l+1)=∂L/∂z(l+1),那么第 l 层的 z(l) 导数可以通过反向传播从第 l+1 层传递下来。

使用链式法则,隐藏层 z(l) 的导数为:
在这里插入图片描述

  • 计算∂L/∂a(l)
    使用链式法则,损失函数 L 对隐藏层 a(l) 的导数为:
    在这里插入图片描述
    根据线性组合的公式 z(l+1)=W(l+1)a(l)+b(l+1),z(l+1) 对 a(l) 的导数为:
    在这里插入图片描述
    因此,∂L/∂a(l)为:
    在这里插入图片描述
    为了保持一致性,我们通常将 W(l+1) 转置,使得矩阵运算中的维度保持一致。

  • 计算∂a(l)/∂z(l)

a(l) 是 z(l)z(l) 通过激活函数 gg 后的结果,因此:
在这里插入图片描述

综上,对于隐藏层的第 l 层,其误差计算公式为:
在这里插入图片描述
∘ 表示逐元素相乘(Hadamard 乘积),激活函数是逐元素应用到每个神经元输出的,而不是对整个向量进行操作。因此,第 l 层的每个神经元在反向传播时都会依赖于其对应的激活函数导数。

3. 计算梯度

一旦我们得到了每一层的误差 δ(l),我们就可以计算每一层权重和偏置的梯度。梯度是描述损失函数相对于权重或偏置变化率的一个量。在反向传播阶段,我们通过链式法则计算损失函数对各层权重 W(l) 和偏置 b(l) 的梯度,即:
在这里插入图片描述
这些梯度表示每个权重和偏置对最终损失 L 的影响。它们通过链式法则逐层向前回传,详细步骤如下:

3.1 对于权重矩阵 W(l)

通过链式法则计算损失函数对权重 W(l) 的导数:
在这里插入图片描述
得到结果:
在这里插入图片描述
这里,a{(l-1)}T 是上一层的激活值的转置,目的是确保矩阵的维度正确。由于 W(l) 是一个矩阵,通常 a(l−1) 是一个列向量,因此 a{(l-1)}T 是一个行向量。

3.2 对于偏置向量 b(l)

通过链式法则,我们可以计算损失函数 L 对偏置 b(l) 的导数:
在这里插入图片描述
在线性组合公式中,偏置 b(l) 是直接加到每个神经元的线性组合 z(l) 中的。因此,z(l) 对 b(l) 的导数是 1:
在这里插入图片描述
所以:
在这里插入图片描述

四、 权重和偏置更新(Weight and Bias Update)

使用梯度下降法,根据反向传播计算得到的梯度更新权重和偏置。

1. 权重更新公式:

对于第 l 层的权重 W(l),更新公式为:

在这里插入图片描述
其中:

  • η 是学习率。
  • ∂W(l)∂E​=δ(l)(a(l−1))T 是损失函数对第 l 层权重的梯度。
2. 偏置更新公式:

类似地,第 l 层的偏置 b(l) 更新公式为:

在这里插入图片描述

五、 循环迭代

通过多次迭代(通常称为训练迭代(epochs)),重复进行前向传播、损失函数计算、反向传播以及权重和偏置的更新,直到网络收敛,即损失函数的值不再显著下降,或者达到了预设的迭代次数。

Rumelhart 对反向传播算法的贡献:

David Rumelhart 及其同事的主要贡献在于:

  • 他们系统化地提出了反向传播算法,使得该算法可以有效应用于多层神经网络的训练,解决了之前单层感知器模型的局限性。
  • 他们展示了如何通过反向传播算法训练深层网络,使得网络能够从数据中学习复杂的模式表示。这为后来的深度学习发展奠定了基础。

反向传播的意义与局限:

反向传播算法是现代深度学习的核心之一,它使得多层神经网络能够成功训练,解决了许多复杂的任务(如图像识别、语音识别等)。但是,它也有一些局限性,例如:

  • 梯度消失问题(vanishing gradient):在深层神经网络中,反向传播的梯度逐渐减小,导致前几层权重更新非常缓慢。
  • 训练时间长:当网络层数增加或数据集规模扩大时,训练时间可能会变得非常长。

尽管如此,反向传播算法依然是当今神经网络训练的基础,配合现代改进的优化方法(如Adam、RMSprop等)和技术(如Batch Normalization、Dropout等),反向传播已经极大地提升了神经网络的学习效率和表现。


http://www.niftyadmin.cn/n/5683510.html

相关文章

哈希表(HashMap、HashSet)

文章目录 一、 什么是哈希表二、 哈希冲突2.1 为什么会出现冲突2.2 如何避免出现冲突2.3 出现冲突如何解决 三、模拟实现哈希桶/开散列(整型数据)3.1 结构3.2 插入元素3.3 获取元素 四、模拟实现哈希桶/开散列(泛型)4.1 结构4.2 插…

python绘制动态残差图,plot交互模式

python绘制动态残差图 动态刷新数据,交互模式 # 开启交互模式plt.ion()# 创建初始数据x_line [0, 1]err_wage [0, 10]# 创建图形和轴fig, ax plt.subplots()line, ax.plot(x_line, err_wage, b-) # b-表示蓝色实线# ax.set_xlim(0, 20) # 设置x轴的范围# ax.…

甘肃非遗文化网站:Spring Boot开发实战

3 系统分析 当用户确定开发一款程序时,是需要遵循下面的顺序进行工作,概括为:系统分析–>系统设计–>系统开发–>系统测试,无论这个过程是否有变更或者迭代,都是按照这样的顺序开展工作的。系统分析就是分析系…

什么是SQL注入?

SQL注入是一种安全漏洞,攻击者通过在应用程序的输入字段中插入恶意SQL代码,从而操控数据库。此类攻击通常利用应用程序未对用户输入进行适当验证和清理的弱点。 工作原理: 输入字段:攻击者在登录表单或搜索框等输入区域插入恶意…

js实现两个轴直线插补圆弧插补

效果图 源代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Multi-Axis Motion with Canvas</title><style>body {margin: 0;}#controls {position: absolute;top: 10px;right: 10px;…

不同领域的常见 OOD(Out-of-Distribution)数据集例子

以下是几个来自不同领域的常见 OOD&#xff08;Out-of-Distribution&#xff09;数据集例子&#xff0c;这些数据集常用于测试和研究模型在分布变化或分布外数据上的泛化能力&#xff1a; 1. 计算机视觉领域 CIFAR-10 vs. CIFAR-10-C / CIFAR-100-C: 描述&#xff1a;CIFAR-10…

滚雪球学MySQL[6.1讲]:数据备份与恢复

全文目录&#xff1a; 前言6. 数据备份与恢复6.1 备份的基础知识6.1.1 备份的重要性6.1.2 备份的类型 6.2 备份策略6.2.1 完全备份与增量备份结合6.2.2 定期检查备份有效性6.2.3 异地备份 6.3 MySQL备份工具6.3.1 mysqldump6.3.2 mysqlhotcopy6.3.3 Percona XtraBackup 6.4 数据…

LSTM预测未来30天销售额

加入深度实战社区:www.zzgcz.com&#xff0c;免费学习所有深度学习实战项目。 1. 项目简介 本项目旨在利用深度学习中的长短期记忆网络&#xff08;LSTM&#xff09;来预测未来30天的销售额。LSTM模型能够处理时序数据中的长期依赖问题&#xff0c;因此在销售额预测这类涉及时…