What is the difference between LSTMCell and LSTM in Keras?
When building deep learning models in Keras, the recurrent-layer module exposes two APIs, LSTMCell and LSTM. How do we tell them apart?
1 Look at the source code first
The source of LSTMCell:
class LSTMCell(Layer):
    def __init__(self, units, **kwargs):
        # ... configuration elided ...
        pass

    def build(self, input_shape):
        # ... weight creation elided ...
        self.built = True

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state
        if self.implementation == 1:
            # one matrix product per gate
            x_i = K.dot(inputs_i, self.kernel_i)
            x_i = K.bias_add(x_i, self.bias_i)
            i = self.recurrent_activation(
                x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
            # ... remaining gates elided ...
        else:
            # fused implementation: one big matrix product, then split
            z0 = z[:, :self.units]
            i = self.recurrent_activation(z0)
            # ... remaining gates elided ...
        h = o * self.activation(c)
        return h, [h, c]

    def get_config(self):
        # ... serialization elided ...
        pass
This is only an excerpt, so it differs from the full source. First note the inheritance: LSTMCell extends Keras's base Layer class. Its call method spells out, in detail, the computation for a single time step.
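To make that per-step math concrete, here is a minimal NumPy sketch of what one such step computes, following the fused formulation (one big matrix product split into the four gate pre-activations). The function name and weight layout are illustrative, not Keras's actual internals:

```python
import numpy as np

def lstm_step(x, h_tm1, c_tm1, W, U, b, units):
    """One LSTM step. W: (input_dim, 4*units), U: (units, 4*units),
    b: (4*units,). Returns output h and new state [h, c],
    mirroring LSTMCell.call's `return h, [h, c]`."""
    z = x @ W + h_tm1 @ U + b                 # (batch, 4*units)
    z0, z1, z2, z3 = np.split(z, 4, axis=-1)  # one slice per gate
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    i = sigmoid(z0)                           # input gate
    f = sigmoid(z1)                           # forget gate
    c = f * c_tm1 + i * np.tanh(z2)           # new carry state
    o = sigmoid(z3)                           # output gate
    h = o * np.tanh(c)                        # new memory state / output
    return h, [h, c]
```

A recurrent layer would simply call this function once per time step, threading `[h, c]` from one call to the next.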
Now look at the source of LSTM:
class LSTM(RNN):

    @interfaces.legacy_recurrent_support
    def __init__(self, units, **kwargs):  # full argument list abridged
        # ... default-argument handling elided ...
        cell = LSTMCell(units,
                        activation=activation,
                        recurrent_activation=recurrent_activation,
                        use_bias=use_bias,
                        kernel_initializer=kernel_initializer,
                        recurrent_initializer=recurrent_initializer,
                        unit_forget_bias=unit_forget_bias,
                        bias_initializer=bias_initializer,
                        kernel_regularizer=kernel_regularizer,
                        recurrent_regularizer=recurrent_regularizer,
                        bias_regularizer=bias_regularizer,
                        kernel_constraint=kernel_constraint,
                        recurrent_constraint=recurrent_constraint,
                        bias_constraint=bias_constraint,
                        dropout=dropout,
                        recurrent_dropout=recurrent_dropout,
                        implementation=implementation)
        super(LSTM, self).__init__(cell,
                                   return_sequences=return_sequences,
                                   return_state=return_state,
                                   go_backwards=go_backwards,
                                   stateful=stateful,
                                   unroll=unroll,
                                   **kwargs)
        self.activity_regularizer = regularizers.get(activity_regularizer)

    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super(LSTM, self).call(inputs,
                                      mask=mask,
                                      training=training,
                                      initial_state=initial_state)

    # ... remaining methods elided ...
In the LSTM class we can clearly see that it inherits from RNN, i.e. it is a recurrent layer. In its __init__ method it constructs an LSTMCell and hands it to RNN, so LSTMCell serves as the computation unit inside the loop.
In general, a recurrent layer contains a cell object: the cell holds the core code for computing a single step, while the recurrent layer drives the cell and performs the actual recurrence over the sequence.
2 Usage
Typically, people use the LSTM layer directly in their code, or they use an RNN layer that wraps an LSTMCell.
# Option 1: the LSTM layer
model = Sequential()
model.add(LSTM(10))
...

# Option 2: an RNN layer wrapping an LSTMCell
cell = LSTMCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)
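The division of labor behind Option 2 can be sketched without any framework. Below, MinimalTanhCell and MinimalRNNLayer are illustrative names, not Keras classes: the cell computes one step, and the layer loops it over the time axis, just as keras.layers.RNN does with an LSTMCell:

```python
import numpy as np

class MinimalTanhCell:
    """Holds the per-step computation (a simple tanh cell for brevity)."""
    def __init__(self, units, input_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.units = units
        self.W = 0.1 * rng.normal(size=(input_dim, units))
        self.U = 0.1 * rng.normal(size=(units, units))

    def step(self, x_t, state):
        h = np.tanh(x_t @ self.W + state @ self.U)
        return h, h  # output and new state

class MinimalRNNLayer:
    """Drives a cell across the time axis, like keras.layers.RNN."""
    def __init__(self, cell, return_sequences=False):
        self.cell = cell
        self.return_sequences = return_sequences

    def __call__(self, x):  # x: (batch, time, features)
        batch, time, _ = x.shape
        state = np.zeros((batch, self.cell.units))
        outputs = []
        for t in range(time):
            out, state = self.cell.step(x[:, t, :], state)
            outputs.append(out)
        if self.return_sequences:
            return np.stack(outputs, axis=1)  # (batch, time, units)
        return outputs[-1]                    # (batch, units)
```

Swapping in a different cell changes the per-step math but not the loop, which is exactly why Keras factors the two apart.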
3 The difference
The LSTM layer is hard-wired to use LSTMCell as its computation unit, while LSTMCell itself performs only a single step of the computation.
- LSTM is a recurrent layer.
- LSTMCell is the object used by the LSTM layer (it happens to be a layer as well) and contains the computation logic for one step.
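Assuming TensorFlow 2's bundled Keras, the contrast shows up directly in how the two are called: the LSTM layer consumes a whole sequence in one call, while an LSTMCell must be stepped manually, with the caller carrying the [h, c] state (the shapes below are arbitrary):

```python
import tensorflow as tf

x = tf.random.normal((2, 7, 5))  # (batch, time, features)

# LSTM: a recurrent layer that runs the whole sequence internally
layer = tf.keras.layers.LSTM(4)
out = layer(x)                   # (2, 4): output at the last time step

# LSTMCell: one step per call; the caller threads the [h, c] state
cell = tf.keras.layers.LSTMCell(4)
state = [tf.zeros((2, 4)), tf.zeros((2, 4))]
for t in range(7):
    out_t, state = cell(x[:, t, :], state)
# out_t also has shape (2, 4)
```

The manual loop is what keras.layers.RNN automates, which is why you would only reach for LSTMCell directly when you need step-level control, e.g. a custom decoding loop.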