pytorch中的TensorDataset和DataLoader

news/2024/10/8 20:44:09 标签: pytorch, 深度学习, 人工智能

TensorDataset 详解

TensorDataset 主要用于将多个 Tensor 组合在一起,方便对数据进行统一处理。它可以用于简单地将特征和标签配对,也可以将多个特征张量组合在一起。

1. 将特征和标签组合

假设我们有一组图像数据(特征)和对应的标签,我们可以将它们组合成一个 TensorDataset

import torch
from torch.utils.data import TensorDataset

# 创建输入数据(图像)和标签
images = torch.randn(100, 3, 28, 28)  # 100张图像,每张图像3通道,28x28像素
labels = torch.randint(0, 10, (100,))  # 100个标签,范围在0到9之间

# 创建 TensorDataset
dataset = TensorDataset(images, labels)

# 访问数据集中的特定样本
sample_image, sample_label = dataset[0]
print(f"Sample Image Shape: {sample_image.shape}")  # 输出: Sample Image Shape: torch.Size([3, 28, 28])
print(f"Sample Label: {sample_label}")  # 输出: Sample Label: 3

在这个例子中,我们创建了一个包含100张图像和对应标签的 TensorDataset。通过 dataset[0],我们可以访问第一个样本的图像和标签。

2. 组合多个特征张量

除了将特征和标签组合,TensorDataset 还可以将多个特征张量组合在一起。例如,假设我们有两个不同的特征张量,我们可以将它们组合成一个 TensorDataset

# 创建两个特征张量
feature1 = torch.randn(100, 50)  # 100个样本,每个样本50维
feature2 = torch.randn(100, 30)  # 100个样本,每个样本30维

# 创建 TensorDataset
dataset = TensorDataset(feature1, feature2)

# 访问数据集中的特定样本
sample_feature1, sample_feature2 = dataset[0]
print(f"Sample Feature1 Shape: {sample_feature1.shape}")  # 输出: Sample Feature1 Shape: torch.Size([50])
print(f"Sample Feature2 Shape: {sample_feature2.shape}")  # 输出: Sample Feature2 Shape: torch.Size([30])

在这个例子中,我们创建了一个包含两个特征张量的 TensorDataset,并通过 dataset[0] 访问第一个样本的两个特征。

DataLoader 详解

DataLoader 主要用于批量加载数据,并支持多种数据处理功能,如随机打乱、多线程加载等。

1. 批量处理数据

DataLoader 可以将数据集划分为多个批次(batch),便于模型训练。

from torch.utils.data import DataLoader

# 创建 DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=False)

# 遍历 DataLoader
for batch_features, batch_labels in train_loader:
    print(f"Batch Features Shape: {batch_features.shape}")  # 输出: Batch Features Shape: torch.Size([32, 3, 28, 28])
    print(f"Batch Labels Shape: {batch_labels.shape}")  # 输出: Batch Labels Shape: torch.Size([32])
    # 这里可以进行训练操作,如前向传播、反向传播等

在这个例子中,train_loader 将数据集划分为大小为32的批次。通过遍历 train_loader,我们可以轻松地获取每个批次的特征和标签。

2. 数据打乱

DataLoader 可以通过设置 shuffle=True 来在每个 epoch 开始时随机打乱数据,避免模型学习到数据的顺序。

# 创建 DataLoader,并设置 shuffle=True
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 遍历 DataLoader
for epoch in range(2):  # 假设我们要训练两个 epoch
    for batch_features, batch_labels in train_loader:
        print(f"Epoch {epoch}, Batch Features Shape: {batch_features.shape}")
        # 这里可以进行训练操作

在这个例子中,每次 epoch 开始时,数据都会被随机打乱,确保模型不会受到数据顺序的影响。

3. 多线程加载

DataLoader 支持通过设置 num_workers 参数来使用多线程并行加载数据,加快数据读取速度。

# 创建 DataLoader,并设置 num_workers=4
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# 遍历 DataLoader
for batch_features, batch_labels in train_loader:
    print(f"Batch Features Shape: {batch_features.shape}")
    # 这里可以进行训练操作

在这个例子中,我们设置了 num_workers=4,表示使用4个线程来并行加载数据,从而加快数据读取速度。

结合使用 TensorDataset 和 DataLoader

以下是一个完整的示例,展示了如何结合使用 TensorDataset 和 DataLoader 进行数据加载和训练。

import torch
from torch.utils.data import TensorDataset, DataLoader

# 创建输入数据和标签
images = torch.randn(1000, 3, 28, 28)  # 1000张图像,每张图像3通道,28x28像素
labels = torch.randint(0, 10, (1000,))  # 1000个标签,范围在0到9之间

# 创建 TensorDataset
dataset = TensorDataset(images, labels)

# 创建 DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# 遍历 DataLoader 进行训练
for epoch in range(2):
    for batch_images, batch_labels in train_loader:
        print(f"Epoch {epoch}, Batch Images Shape: {batch_images.shape}")
        print(f"Epoch {epoch}, Batch Labels Shape: {batch_labels.shape}")
        # 这里可以进行训练操作,如前向传播、反向传播等

在这个例子中,我们首先使用 TensorDataset 将图像和标签组合在一起,然后通过 DataLoader 进行批量加载和训练。通过设置 shuffle=True 和 num_workers=4,我们实现了数据的随机打乱和多线程加载。

总结

  • TensorDataset 用于将多个 Tensor 组合在一起,方便对数据进行统一处理。
    • 可以组合特征和标签。
    • 可以组合多个特征张量。
  • DataLoader 用于批量加载数据,支持多种数据处理功能。
    • 支持批量处理数据。
    • 支持数据打乱。
    • 支持多线程加载。


http://www.niftyadmin.cn/n/5694766.html

相关文章

【RTCP】Interarrival Jitter: 到达间隔抖动的举例说明

Interarrival Jitter: 32位,表示接收到的数据包间隔的抖动,用于评估网络延迟变化情况。给出一些具体实例,帮助理解“Interarrival Jitter”(到达间隔抖动)是指接收到连续数据包之间时间间隔的变化情况,这个指标用于评估网络的稳定性和延迟的波动情况。在RTP中,接收端会计…

【大数据】在线分析、近线分析与离线分析

文章目录 1. 在线分析(Online Analytics)定义特点应用场景技术栈 2. 近线分析(Nearline Analytics)定义特点应用场景技术栈 3. 离线分析(Offline Analytics)定义特点应用场景技术栈 总结 在线分析&#xff…

vue2与vue3知识点

1.vue2(optionsAPI)选项式API 2.vue3(composition API)响应式API vue3 setup 中this是未定义(undefined)vue3中已经开始弱化this vue2通过this可以拿到vue3setup定义得值和方法 setup语法糖 ref > …

笔记整理—linux进程部分(8)线程与进程

前面用了高级IO去实现鼠标和键盘的读取&#xff0c;也说过要用多进程方式进行该操作&#xff1a; int mian(void) {int ret-1;int fd-1;char bug[100]{0};retfork();if(0ret){//子进程&#xff0c;读鼠标}if(0<ret){//父进程&#xff0c;读键盘}else{perror("fork&quo…

浅色系统B端管理系统标配,现在也卷起了可视化,挡不住呀

在 B 端管理系统的领域中&#xff0c;浅色系统一直以来都是标配之选。其简洁、清新的外观&#xff0c;给人以专业、高效的视觉感受。如今&#xff0c;浅色系统更是卷入了可视化的浪潮&#xff0c;这一趋势势不可挡。 浅色系统的优势在于它能够营造出一种舒适的视觉环境&#x…

​自动猫砂盆到底有没有必要?过来人经验:千万别再盲目选择了!

不知道大家养猫有没有一样的烦恼&#xff0c;就是上班时间到底要怎么保证猫砂盆里的猫屎被铲干净呢&#xff1f;放一天不铲的话&#xff0c;一次两次还行&#xff0c;长期这样就会害的猫砂盆内部细菌增多&#xff0c;甚至长虫&#xff0c;严重危害小猫的健康安全&#xff0c;但…

OpenHarmony(鸿蒙南向开发)——标准系统方案之瑞芯微RK3566移植案例(下)

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ 持续更新中…… 概述 OpenHarmony Camera驱动模型结构 HDI Implementation&#x…

详细介绍pandas 在python中的用法

Pandas 是 Python 中非常流行的数据分析和处理库&#xff0c;特别适用于处理结构化数据。它构建在 NumPy 之上&#xff0c;提供了更高级的功能&#xff0c;例如数据清理、整理、筛选和统计分析。Pandas 的核心数据结构是 Series 和 DataFrame&#xff0c;分别用于处理一维数据和…