RuntimeError: Expected object of type torch.DoubleTensor but found type torch.FloatTensor for argument #2 'weight'解决办法

标签:Pytorch, dl
RuntimeError: Expected object of type torch.DoubleTensor but found type torch.FloatTensor for argument #2 'weight'

今天遇到这个问题,一开始以为是输入的Tensor的问题,后来一检查发现是Model中的weight的数据类型跟输入不同的关系。

输入的是DoubleTensor类型,而模型默认的是FloatTensor类型,导致数据类型不一致报错。

解决的办法就是在模型创建和输入的时候统一数据类型,比如:

		self.cnn7=torch.nn.Conv2d(1, 1, (7,100), stride=(1,1),padding=(3,0)).float()
		self.linear=torch.nn.Linear(3,2).float()

另外值得注意的一点是FloatTensor在CPU中的执行效率比DoubleTensor更高。因为尽量采用FloatTensor类型。


原创文章,转载请注明出处!
本文链接:http://youran.tech/posts/pytorch-Expected-object-of-type-torch-DoubleTensor-argument.html
上篇: opencv中如何合并多张图片到一张图片中
下篇: Microsoft COCO-Common Objects in Context(中文版)

留言: