news 2026/4/4 16:54:06

TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(一)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(一)

tf.keras.losses.SparseCategoricalCrossentropy核心原理

SparseCategoricalCrossentropy(稀疏类别交叉熵)是 TensorFlow/Keras 中针对多分类任务的损失函数,专为稀疏标签(整数型标签,如0,1,2)设计,核心作用是衡量模型输出的类别概率分布与真实稀疏标签的「差异」,本质是交叉熵(Cross-Entropy)在稀疏标签场景下的优化实现。

一、先理解核心背景:交叉熵的本质

交叉熵源于信息论,用于衡量两个概率分布的「距离」(差异程度)。对于多分类任务:

  • 真实标签的分布是「one-hot 分布」(比如 3 分类中标签为 1,对应分布是[0,1,0]);
  • 模型输出是类别概率分布(经 Softmax 归一化后,和为 1,如[0.1,0.8,0.1])。

交叉熵的公式为:
H(p,q)=−∑i=1Cp(i)log⁡(q(i)) H(p,q) = -\sum_{i=1}^C p(i) \log(q(i))H(p,q)=i=1Cp(i)log(q(i))
其中:

  • ppp:真实标签的概率分布(one-hot 形式,仅目标类别为 1,其余为 0);
  • qqq:模型预测的类别概率分布;
  • CCC:类别总数。

由于ppp是 one-hot 分布,交叉熵可简化为:仅取目标类别对应的预测概率的负对数(因为其他项都是0×log⁡(q(i))=00 \times \log(q(i))=00×log(q(i))=0)。

二、SparseCategoricalCrossentropy 的核心适配:稀疏标签

普通的CategoricalCrossentropy要求标签是one-hot 编码(如 3 分类标签 1 对应[0,1,0]),而SparseCategoricalCrossentropy直接支持整数型稀疏标签(如 1),无需手动 one-hot 编码,核心优势是节省内存(尤其是类别数多的场景,比如 1000 类时,稀疏标签仅存 1 个整数,one-hot 需存 1000 维向量)。

三、完整计算逻辑(分两种场景)

SparseCategoricalCrossentropy的关键参数是from_logits(默认False),决定模型输出是否为「原始 logits(未归一化的得分)」或「Softmax 归一化后的概率」,两种场景的计算逻辑不同(TensorFlow 内部做了优化,避免数值不稳定)。

场景 1:from_logits=False(默认,模型输出是 Softmax 概率)

假设:

  • 类别数C=3C=3C=3
  • 真实稀疏标签y=1y=1y=1(对应目标类别是第 2 类,索引从 0 开始);
  • 模型输出 Softmax 概率q=[0.1,0.8,0.1]q=[0.1, 0.8, 0.1]q=[0.1,0.8,0.1]

计算步骤:

  1. 取真实标签对应的概率:q(y)=q(1)=0.8q(y)=q(1)=0.8q(y)=q(1)=0.8
  2. 计算负对数:−log⁡(q(y))=−log⁡(0.8)≈0.223-\log(q(y)) = -\log(0.8) ≈ 0.223log(q(y))=log(0.8)0.223
  3. 最终损失值即为该结果(批量数据会取均值/求和,由reduction参数控制)。

公式简化为:
loss=−log⁡(q(y)) \text{loss} = -\log(q(y))loss=log(q(y))

场景 2:from_logits=True(模型输出是原始 logits,推荐!)

模型输出的是未经过 Softmax 归一化的原始得分(logits,如z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]),此时 TensorFlow 不会先单独计算 Softmax(避免数值下溢/上溢),而是直接用log_softmax优化计算:

  1. 对 logits 计算log_softmaxlog⁡(Softmax(z))=z−log⁡(∑i=1Cezi)\log(\text{Softmax}(z)) = z - \log(\sum_{i=1}^C e^{z_i})log(Softmax(z))=zlog(i=1Cezi)
  2. 取真实标签对应的项,取负数即为损失:
    loss=−(zy−log⁡(∑i=1Cezi)) \text{loss} = - \left( z_y - \log(\sum_{i=1}^C e^{z_i}) \right)loss=(zylog(i=1Cezi))

示例计算(z=[1.0,3.0,0.5],y=1z=[1.0, 3.0, 0.5], y=1z=[1.0,3.0,0.5],y=1):

  • 先算∑ezi=e1.0+e3.0+e0.5≈2.718+20.085+1.648≈24.451\sum e^{z_i} = e^{1.0} + e^{3.0} + e^{0.5} ≈ 2.718 + 20.085 + 1.648 ≈ 24.451ezi=e1.0+e3.0+e0.52.718+20.085+1.64824.451
  • log⁡(24.451)≈3.200\log(24.451) ≈ 3.200log(24.451)3.200
  • log⁡(Softmax(z))1=3.0−3.200=−0.200\log(\text{Softmax}(z))_1 = 3.0 - 3.200 = -0.200log(Softmax(z))1=3.03.200=0.200
  • 损失值:−(−0.200)=0.200-(-0.200) = 0.200(0.200)=0.200

为什么推荐from_logits=True
Softmax 对大 logits 会产生e大值e^{大值}e大值(如e100e^{100}e100溢出),而log_softmax直接通过代数变换避免了单独计算 Softmax,提升数值稳定性。

四、批量数据的损失归约

实际训练中输入是批量数据(batch),损失会通过reduction参数归约(默认AUTO,等价于SUM_OVER_BATCH_SIZE):

  • 对每个样本计算损失值;
  • 求批量内所有样本损失的均值(或求和,取决于reduction)。

示例(batch_size=2):

样本稀疏标签模型概率单样本损失
11[0.1,0.8,0.1]0.223
20[0.9,0.05,0.05]0.105
批量损失 = (0.223 + 0.105) / 2 ≈ 0.164。

五、关键参数解析

参数作用示例
from_logits是否输入为原始 logits(非 Softmax 概率)from_logits=True(推荐)
reduction损失归约方式:
-NONE:返回每个样本的损失
-SUM:批量损失求和
-SUM_OVER_BATCH_SIZE:批量损失求均值
reduction="sum_over_batch_size"
ignore_index忽略指定标签(计算损失时跳过),适用于样本标注缺失场景ignore_index=-1
axis类别维度(默认 -1,即最后一维是类别)模型输出形状(batch, 3)时,axis=-1 对应 3 个类别

六、与CategoricalCrossentropy的对比

特性SparseCategoricalCrossentropyCategoricalCrossentropy
标签格式整数型稀疏标签(如 1,2,3)one-hot 编码标签(如 [0,1,0])
内存占用低(仅存整数)高(类别数维向量)
适用场景类别数多、标签天然为整数(如图像分类的类别索引)标签已做 one-hot 编码
核心公式同交叉熵,但直接取整数标签对应项交叉熵原始公式(遍历所有类别)

七、注意事项

  1. 标签范围:稀疏标签必须是[0,C−1][0, C-1][0,C1]范围内的整数(C 是类别数),否则会报错;
  2. 数值稳定性:优先设置from_logits=True,避免 Softmax 导致的数值溢出;
  3. 多标签任务:该损失适用于「单标签多分类」(每个样本仅属于一个类别),多标签任务需用BinaryCrossentropy

示例代码验证

importtensorflowastf# 1. 定义损失函数(from_logits=True,模型输出logits)loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)# 2. 模拟批量数据(batch_size=2,类别数=3)y_true=tf.constant([1,0])# 稀疏标签y_pred_logits=tf.constant([[1.0,3.0,0.5],[5.0,1.0,0.1]])# 模型输出logits# 3. 计算损失loss=loss_fn(y_true,y_pred_logits)print("批量损失值:",loss.numpy())# 输出约 0.15(手动计算验证)

综上,SparseCategoricalCrossentropy本质是「多分类交叉熵」在稀疏标签下的高效实现,核心是通过直接索引整数标签避免 one-hot 编码,同时优化数值计算保证稳定性,是单标签多分类任务的首选损失函数之一。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/26 20:44:38

【AI模型移动端落地新纪元】:Open-AutoGLM手机部署实战经验全公开

第一章:Open-AutoGLM移动端落地的时代背景随着人工智能技术的迅猛发展,大语言模型(LLM)逐步从云端向终端设备迁移。Open-AutoGLM作为面向移动场景优化的开源自动推理框架,正是在这一趋势下应运而生。其核心目标是将强大…

作者头像 李华
网站建设 2026/4/4 8:58:41

Iwara视频下载与高效管理终极解决方案

Iwara视频下载与高效管理终极解决方案 【免费下载链接】IwaraDownloadTool Iwara 下载工具 | Iwara Downloader 项目地址: https://gitcode.com/gh_mirrors/iw/IwaraDownloadTool 你是否曾为Iwara视频下载的繁琐流程而烦恼?发现心仪的视频却无法批量保存&…

作者头像 李华
网站建设 2026/3/17 5:33:57

资源嗅探工具实战指南:轻松捕获网页媒体资源,告别下载烦恼

你是否也遇到过这样的情况?😊 看到喜欢的在线视频想保存下来,却发现网站没有提供下载按钮;想要收藏精彩的音乐片段,却不知道如何获取源文件;或者想在手机和电脑间快速传输资源,却找不到便捷的方…

作者头像 李华
网站建设 2026/4/3 19:56:09

你真的懂Open-AutoGLM吗?3个关键问题揭示其背后隐藏的设计哲学

第一章:你真的懂Open-AutoGLM吗?重新审视其本质与定位Open-AutoGLM 并非一个简单的开源模型,而是一种融合了自动化推理、图学习与语言生成的新型智能架构。它试图打破传统大语言模型在任务泛化与结构理解上的边界,通过引入动态图构…

作者头像 李华
网站建设 2026/3/12 22:23:23

如何快速掌握IwaraDownloadTool:面向新手的完整使用教程

IwaraDownloadTool是一款专为Iwara平台设计的高效视频下载工具,能够帮助用户轻松保存喜爱的视频内容。无论你是初次接触还是想要提升下载效率,本教程都将为你提供全面的操作指导。 【免费下载链接】IwaraDownloadTool Iwara 下载工具 | Iwara Downloader…

作者头像 李华
网站建设 2026/3/28 15:46:17

AtCoder Beginner Contest竞赛题解 | 洛谷 AT_abc437_b Tombola

​欢迎大家订阅我的专栏:算法题解:C与Python实现! 本专栏旨在帮助大家从基础到进阶 ,逐步提升编程能力,助力信息学竞赛备战! 专栏特色 1.经典算法练习:根据信息学竞赛大纲,精心挑选…

作者头像 李华