pytorch自定义Dataset
admin
2024-01-17 12:47:54
0

因为需要读取大量数据到神经网络里进行训练,之前一直使用的keras.fit不管用了,后来发现pytorch自带的Dataset和Dataloader能很好的解决这个问题。如果使用tensorflow的话,需要使用tf.data.Dataset.from_tensor_slices().map()方法或者使用队列来解决这个问题,
tensorflow自定义Dataset教程链接:
http://www.51zixue.net/TensorFlow/765.html

在网上找了一些教程,只写了一些基础的代码,没有讲清楚为啥这么写,有些bug也没有提示。
这里写一下我自己的理解:
首先自定义Dataset必须要写一个继承from torch.utils.data import Dataset的类,其中除了init方法以外还有两个方法,__getitem__()和__len__(),可以这么理解:在使用pytorch自带的Dataloader把Dataset包裹起来调用的时候,会认为这个Dataset一共有的数据量就是__len__()的返回值,比如Dataloader的batch参数为8,即一次读取8个数据,它就会产生8个不同的数值,把这些数值作为__getitem__()的参数输入进去调用,然后把返回的每次返回的数据,共8个,打包好来给用户。

其中,get_item()的返回值也没要必须是(一个数据+一个label)的形式,只要有返回值就可以,只不过相对应的,在遍历Dataloader,其实也就是在遍历这些返回值,只要做好相应处理即可

其中我遇到了两个报错是和这部分有关的

ValueError: num_samples should be a positive integer value, but got num_samples=0

这个原因比较简单,就是 __len__(self)返回值是0,导致程序认为不存在样本数量,关注修改这部分即可

第二个:
UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.

这个报错的原因比较复杂,主要原因就是报错里说的,The given NumPy array is not writeable。我在本地测试正常,但是把程序部署到gpu算力平台上时出现了这个问题,解决方法是在

__getitem__(self, index)

这个函数的返回值里,把原来返回的feature用np.array()包裹,注意feature原本就是numpy数组,这里再调用一次np.array是为了达到copy的效果,从而解决这个问题。

下面附上整段的代码

# 准备pytorch的数据
from torch.utils.data import Dataset, DataLoader
from OSutils import get_data_path, load_jsondata
from ByteSequencesFeature import byte_sequences_feature
from torch.utils.data import DataLoader
import numpy as np
import torchdef data_loader_multilabel(file_path='', label_dict={}):# 用于读取多标签的情况file_md5 = file_path.split('/')[-1]return byte_sequences_feature(file_path), label_dict.get(file_md5)def data_loader(file_path='', label_dict={}):# 用于读取单标签的情况file_md5 = file_path.split('/')[-1]if file_md5 in label_dict:return byte_sequences_feature(file_path), 1else:return byte_sequences_feature(file_path), 0class MalconvDataSet(Dataset):def __init__(self, black_samples_dir="black_samples/", white_samples_dir='white_samples/',label_dict_path='label_dict.json', label_type="single", valid=False, valid_size=0.2, seed=207):self.file_list = get_data_path(black_samples_dir)self.loader = data_loader_multilabelif label_type == "single":self.loader = data_loaderself.file_list += get_data_path(white_samples_dir)if label_type == "predict":self.label_dict = {}self.loader = data_loaderelse:self.label_dict = load_jsondata(label_dict_path)np.random.seed(seed)np.random.shuffle(self.file_list)# 如果是需要测试集,就在原来的基础上分割# 因为设定了随机种子,所以分割的结果是一样的valid_cut = int((1 - valid_size) * len(self.file_list))if valid:self.file_list = self.file_list[valid_cut:]else:self.file_list = self.file_list[:valid_cut]def __getitem__(self, index):file_path = self.file_list[index]feature, label = self.loader(file_path, self.label_dict)return np.array(feature), labeldef __len__(self):return len(self.file_list)

相关内容

热门资讯

【MySQL】锁 锁 文章目录锁全局锁表级锁表锁元数据锁(MDL)意向锁AUTO-INC锁...
【内网安全】 隧道搭建穿透上线... 文章目录内网穿透-Ngrok-入门-上线1、服务端配置:2、客户端连接服务端ÿ...
GCN的几种模型复现笔记 引言 本篇笔记紧接上文,主要是上一篇看写了快2w字,再去接入代码感觉有点...
数据分页展示逻辑 import java.util.Arrays;import java.util.List;impo...
Redis为什么选择单线程?R... 目录专栏导读一、Redis版本迭代二、Redis4.0之前为什么一直采用单线程?三、R...
【已解决】ERROR: Cou... 正确指令: pip install pyyaml
关于测试,我发现了哪些新大陆 关于测试 平常也只是听说过一些关于测试的术语,但并没有使用过测试工具。偶然看到编程老师...
Lock 接口解读 前置知识点Synchronized synchronized 是 Java 中的关键字,...
Win7 专业版安装中文包、汉... 参考资料:http://www.metsky.com/archives/350.htm...
3 ROS1通讯编程提高(1) 3 ROS1通讯编程提高3.1 使用VS Code编译ROS13.1.1 VS Code的安装和配置...
大模型未来趋势 大模型是人工智能领域的重要发展趋势之一,未来有着广阔的应用前景和发展空间。以下是大模型未来的趋势和展...
python实战应用讲解-【n... 目录 如何在Python中计算残余的平方和 方法1:使用其Base公式 方法2:使用statsmod...
学习u-boot 需要了解的m... 一、常用函数 1. origin 函数 origin 函数的返回值就是变量来源。使用格式如下...
常用python爬虫库介绍与简... 通用 urllib -网络库(stdlib)。 requests -网络库。 grab – 网络库&...
药品批准文号查询|药融云-中国... 药品批文是国家食品药品监督管理局(NMPA)对药品的审评和批准的证明文件...
【2023-03-22】SRS... 【2023-03-22】SRS推流搭配FFmpeg实现目标检测 说明: 外侧测试使用SRS播放器测...
有限元三角形单元的等效节点力 文章目录前言一、重新复习一下有限元三角形单元的理论1、三角形单元的形函数(Nÿ...
初级算法-哈希表 主要记录算法和数据结构学习笔记,新的一年更上一层楼! 初级算法-哈希表...
进程间通信【Linux】 1. 进程间通信 1.1 什么是进程间通信 在 Linux 系统中,进程间通信...
【Docker】P3 Dock... Docker数据卷、宿主机与挂载数据卷的概念及作用挂载宿主机配置数据卷挂载操作示例一个容器挂载多个目...