Keras中的LSTMcell 和LSTM 有什么区别?

在keras构建深度学习模型时,在循环层中存在LSTMcell和LSTM两个API。该如何区分呢?

1 首先看源码

LSTMcell的源码:

class LSTMCell(Layer):
	def __init__(self):                 
	 	*****pass*****

    def build(self, input_shape):
 		*****pass*****
        self.built = True

    def call(self, inputs, states, training=None):
 		*****pass*****
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            x_i = K.dot(inputs_i, self.kernel_i)
            x_i = K.bias_add(x_i, self.bias_i)
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,self.recurrent_kernel_i))
            z0 = z[:, :self.units]
            i = self.recurrent_activation(z0)

        h = o * self.activation(c)
        return h, [h, c]

    def get_config(self):
 		*****pass*****

在这段代码中,我摘取了部分,与源码比较是不同的。我们首先看类的继承,LSTMcell类继承自keras的基类Layer。在它的call方法中可以看到有很详细的单步计算过程。
接下来看LSTM的源码:

class LSTM(RNN):
	@interfaces.legacy_recurrent_support
    def __init__(self):
		*****pass*****
        cell = LSTMCell(units,
                        activation=activation,
                        recurrent_activation=recurrent_activation,
                        use_bias=use_bias,
                        kernel_initializer=kernel_initializer,
                        recurrent_initializer=recurrent_initializer,
                        unit_forget_bias=unit_forget_bias,
                        bias_initializer=bias_initializer,
                        kernel_regularizer=kernel_regularizer,
                        recurrent_regularizer=recurrent_regularizer,
                        bias_regularizer=bias_regularizer,
                        kernel_constraint=kernel_constraint,
                        recurrent_constraint=recurrent_constraint,
                        bias_constraint=bias_constraint,
                        dropout=dropout,
                        recurrent_dropout=recurrent_dropout,
                        implementation=implementation)
        super(LSTM, self).__init__(cell,
                                   return_sequences=return_sequences,
                                   return_state=return_state,
                                   go_backwards=go_backwards,
                                   stateful=stateful,
                                   unroll=unroll,
                                   **kwargs)
        self.activity_regularizer = regularizers.get(activity_regularizer)

    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super(LSTM, self).call(inputs,
                                      mask=mask,
                                      training=training,
                                      initial_state=initial_state)

    *****pass*****

在LSTM方法中,可以很明确的看到其继承自RNN类,即一个循环层。在它的init方法中,它调用了LSTMcell,即使用LSTMcell作为它循环过程的计算单元。
循环图层包含单元对象。单元包含用于计算每个步骤的核心代码,而循环层命令单元并执行实际的循环计算。

2 实现方法

通常,人们LSTM在代码中使用图层。或者他们使用RNN包含的图层LSTMCell。

#LSTM
model = Sequential()
model.add(LSTM(10))
...
#RNN
cells = LSTMCell(32)
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)

3 区别

LSTMcell是LSTM层的实现单元,固定将LSTMcell作为它的计算单元。而LSTMcell是一个单步的计算单元。
- LSTM是一个经常性的层
- LSTMCell是LSTM层使用的对象(恰好也是一个层),它包含一步的计算逻辑。
全部评论

相关推荐

MinatoWu:是这样的,说的太对了
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务