1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| import tensorflow_addons as tfa
# Size of the LSTM hidden/cell state. The encoder and the decoder share it
# because the encoder's final [h, c] state is passed to the decoder cell as its
# initial_state, so the two sizes must agree.
n_lstm_units = 512

# Model inputs: variable-length int32 token-id sequences for the encoder and
# the decoder, plus one scalar per example holding the decoder sequence length.
encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
sequence_lengths = keras.layers.Input(shape=[], dtype=np.int32)

# One Embedding layer is applied to both encoder and decoder inputs, so both
# sides share the same embedding matrix. `vocab_size` and `embed_size` are
# defined elsewhere in the file. NOTE(review): this presumes source and target
# tokens share one vocabulary; confirm.
embeddings = keras.layers.Embedding(vocab_size, embed_size)
encoder_embeddings = embeddings(encoder_inputs)
decoder_embeddings = embeddings(decoder_inputs)

# Encoder: return_state=True makes the layer also return its final hidden and
# cell states. Only those states are consumed below; encoder_outputs is not
# used in this snippet.
encoder = keras.layers.LSTM(n_lstm_units, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_embeddings)
encoder_state = [state_h, state_c]

# Per the TFA docs, TrainingSampler reads each step's next input from the
# supplied decoder inputs (teacher forcing) rather than from its own predictions.
sampler = tfa.seq2seq.sampler.TrainingSampler()

# Decoder: an LSTM cell unrolled by BasicDecoder. The cell starts from the
# encoder's final state. Its per-step output is projected to vocab_size scores
# by a Dense layer with no activation.
decoder_cell = keras.layers.LSTMCell(n_lstm_units)
output_layer = keras.layers.Dense(vocab_size)
decoder = tfa.seq2seq.basic_decoder.BasicDecoder(
    decoder_cell, sampler, output_layer=output_layer)
final_outputs, final_state, final_sequence_lengths = decoder(
    decoder_embeddings,
    initial_state=encoder_state,
    sequence_length=sequence_lengths)

# rnn_output carries the Dense layer's unnormalized scores for each step.
# Softmax turns them into next-token probabilities.
Y_proba = tf.nn.softmax(final_outputs.rnn_output)

model = keras.Model(
    inputs=[encoder_inputs, decoder_inputs, sequence_lengths],
    outputs=[Y_proba])
|