对pytorch里LSTM的理解

由于毕业设计里面要用到LSTM,但是之前没搞懂怎么用,输入输出的格式,所以在这记一下。

先看一下pytorch的官方网址,我用的是1.4版本。其中对于nn.LSTM主要用到的输入参数有以下几个:

参数名 含义
input_size The number of expected features in the input x
hidden_size The number of features in the hidden state h
num_layers Number of recurrent layers

直接上两张图:

 

图1

 

图2
  • 模型参数里面的input_size,官方文档里面写得也很清楚了,是x的特征的个数。图2是普通的神经网络的结构,在RNN/LSTM中,也是类似的。input_size对应于图2的一个xinput,在图2中是4个。
  • hidden_size对应于图2里面hidden的神经元的个数,在图2中是3个。
  • num_layers是隐层的层数,对应于图1中的2层。

整个图2的结构对应于图1中为黄色方框里面的内容。

再来看input的内容:包含了input(h0, c0)

  • input的shape为(seq_len, batch, input_size)seq_len是给模型的序列的长度,对应于图1里面橙色的箭头,也就是4。batch就是batch。input_size是每个x的feature的长度,也就是图2里面的4。

  • h0c0的shape是一样的,因为公式里面每个cell都要用到上一个的hc,当cell不是每一行的第一个时,hc由前面的提供。当cell为第一个时,它们作为输入提供,在图1中有标注。当为单向的LSTM时,shape为(num_layers, batch, hidden_size)。每一层的第一个都需要,所以有一个维度为num_layers,而每一个cell(图1中的绿框或蓝框),都对应于图2,都有hidden_size个隐藏的神经元。

再来看outputoutput包含了output(h_n, c_n)

  • 当LSTM为单向时,output的shape为(seq_len, batch, hidden_size)。对应于图1里面4个蓝框向上的黑色的输出。有4个seq_len,每个cell里面都有hidden_size个节点。

  • 同样的,hncn的shape是一样的,当LSTM为单向时,为(num_layers, batch, hidden_size)。文档里面也说了,这两个是tensor containing the hidden state for t = seq_len。即图1中最后一列青色框内的的绿框和蓝框输出的hc。有num_layers个,每个都有hidden_size个。

c是用于记忆的,h是每个真正输出的,所以下面来做验证:

1
2
3
4
5
6
7
8
9
import torch
from torch import nn

#seq_len为5,batch为3,2层隐层,input的feature为10,隐层的feature为20
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

此时的输出的shape:

1
2
3
4
5
output.shape
torch.Size([5, 3, 20])

hn.shape
torch.Size([2, 3, 20])

输出为:

1
2
3
4
5
6
7
#第二层,也就是最后一层,第0个batch的所有feature
hn[1, 0, :]
tensor([-0.0782, -0.0454, -0.0965, 0.1213, -0.0044, 0.0327, -0.1001, -0.0142, 0.0388, 0.1085, -0.1335, 0.0021, 0.0531, -0.0665, 0.0056, 0.0089, 0.1538, -0.1175, 0.0617, -0.1061], grad_fn=<SliceBackward>)

#output从左往右数最后一个,第0个batch的所有feature
output[4, 0, :]
tensor([-0.0782, -0.0454, -0.0965, 0.1213, -0.0044, 0.0327, -0.1001, -0.0142, 0.0388, 0.1085, -0.1335, 0.0021, 0.0531, -0.0665, 0.0056, 0.0089, 0.1538, -0.1175, 0.0617, -0.1061], grad_fn=<SliceBackward>)

再来一下:

1
2
hn[1, 2, :] == output[4, 2, :]
tensor([True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True])