Understanding Text Generation With LSTM Networks Using TensorFlow

This article explains how recurrent sequence models such as LSTMs can be used for text generation.

Text Generation Using LSTM


Text generation has been getting a lot of attention lately because of its generality and the many use cases it offers. LSTMs were widely used for text generation until Transformers largely replaced them, but LSTMs and other RNN variants such as GRUs are still used in NLP. This article shows how to train an LSTM network from scratch to generate text.

1. What does Text Generation Mean?

Text generation models are prediction models: given the preceding sequence of text, they predict the probability of each possible next word. In other words, the model assigns a probability to every word in the vocabulary, reflecting how likely that word is to come next in the sequence. The word with the highest probability is often the one selected.

Generation happens when we repeat this step iteratively, feeding the model's output at each timestep back in as input to get the next prediction. This is called autoregressive generation. If the model is well trained, its successive predictions combine into a sentence or paragraph that makes sense to us.
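The loop can be sketched in a few lines of Python. This is a toy: the hypothetical NEXT_TOKEN_PROBS table stands in for a trained model and only looks at the last token, whereas an LSTM conditions on the whole context. The sketch picks the most probable token at each step (greedy decoding) and stops at an assumed end token </s>.

```python
# Hypothetical next-token probabilities keyed by the previous token.
# A real LSTM would condition on the whole context, not just the last token.
NEXT_TOKEN_PROBS = {
    "<s>": {"so": 0.7, "long": 0.2, "</s>": 0.1},
    "so": {"long": 0.8, "so": 0.1, "</s>": 0.1},
    "long": {"and": 0.6, "</s>": 0.4},
    "and": {"thanks": 0.9, "</s>": 0.1},
    "thanks": {"</s>": 1.0},
}

def predict_next(context):
    """Return the most probable next token given the context (greedy choice)."""
    probs = NEXT_TOKEN_PROBS[context[-1]]
    return max(probs, key=probs.get)

def generate(max_tokens=10):
    """Autoregressive loop: each prediction is appended and fed back as input."""
    context = ["<s>"]
    for _ in range(max_tokens):
        token = predict_next(context)
        if token == "</s>":  # stop at the end-of-sequence token
            break
        context.append(token)
    return context[1:]  # drop the start token

print(generate())  # ['so', 'long', 'and', 'thanks']
```

Replacing the lookup table with a trained network is exactly what the rest of the article builds toward; the loop itself stays the same.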

In any language modeling task that involves token prediction, the idea is to predict the next word conditioned on the tokens that precede it.

By the chain rule of probability, the probability of a whole sequence is the product of these conditional probabilities:

$$P(w_1, \dots, w_n) = \prod_{t=1}^{n} P(w_t \mid w_1, \dots, w_{t-1})$$

Neural language models use a neural network to estimate these conditional probabilities, while simpler models such as N-grams use maximum likelihood estimates computed from counts over the previous N-1 words. Neural networks tend to perform better than N-gram counts because they can capture the semantic meaning and context of a sequence of tokens. However, deep neural networks, including RNNs, commonly suffer from the vanishing gradient problem. To address it, a variant of RNNs called LSTM (Long Short-Term Memory) networks was introduced.
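On the N-gram side, the maximum likelihood estimate for a bigram model (N = 2) is simply a ratio of counts, P(w_t | w_{t-1}) = count(w_{t-1}, w_t) / count(w_{t-1}). A minimal sketch on a made-up six-word corpus; bigram_mle is a hypothetical helper name:

```python
from collections import Counter

def bigram_mle(tokens):
    """Estimate P(next | prev) = count(prev, next) / count(prev) from a token list."""
    pair_counts = Counter(zip(tokens, tokens[1:]))
    prev_counts = Counter(tokens[:-1])  # only tokens that have a successor
    return {(prev, nxt): c / prev_counts[prev] for (prev, nxt), c in pair_counts.items()}

tokens = "the cat sat the cat ran".split()
probs = bigram_mle(tokens)
print(probs[("the", "cat")])  # 1.0
print(probs[("cat", "sat")])  # 0.5
```

Any bigram never seen in the corpus gets probability zero under this estimate, which is one reason count-based models generalize poorly compared with neural ones.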

2. What's Happening in LSTM?

Let’s briefly review how an LSTM works and how it can be used for text generation. If you are new to LSTMs and want an in-depth understanding of how they work, Christopher Olah’s blog post “Understanding LSTM Networks” is a masterpiece.

LSTMs (Long Short-Term Memory networks) are a type of recurrent neural network (RNN) designed to model sequential data. They can capture and maintain both long-term and short-term dependencies within a sequence, and they process the data one timestep at a time.

Here is what the overall structure of an LSTM cell looks like:

[Figure: structure of an LSTM cell]

The fundamental element of an LSTM is the cell state $C_{t-1}$. It is the network’s long-term memory, storing long-term dependencies in the sequence, while the hidden state $h_{t-1}$ records short-term dependencies and acts as the network’s short-term memory.

Each LSTM cell receives the cell state $C_{t-1}$ and the hidden state $h_{t-1}$ from the previous timestep, along with the current input $x_t$. These inputs go through a series of computations that produce the new cell state $C_t$ and the new hidden state $h_t$. Let’s dive into the operations happening inside the LSTM.

The first operation inside each LSTM cell is the forget operation, handled by the forget gate. The forget gate lets the LSTM discard information in the previous cell state that is no longer relevant given the current input. Forgetting does not mean removing everything, and this is where the sigmoid function comes in: it squashes each value to a number between 0 and 1. Entries close to 0 mark irrelevant information that the LSTM will forget, while entries close to 1 mark relevant information that persists in the cell state. In the LSTM diagram, this is the first operation at the top: the new input $x_t$ and the hidden state $h_{t-1}$ are passed through a sigmoid to determine which parts of the previous cell state to forget. The computation looks like this:

$$f_t = \sigma(U_f h_{t-1} + W_f x_t)$$

Here $U_f$ is the weight matrix applied to the hidden state and $W_f$ is the weight matrix applied to the input. The two weighted terms are summed and passed through a sigmoid activation, and the resulting mask $f_t$ is multiplied element-wise with the previous cell state $C_{t-1}$ to filter out irrelevant information.

The second operation extracts new information and adds it to the cell state; this part is called the add gate (also known as the input gate). Not everything in the current input is worth storing, so we again use a sigmoid to separate the relevant from the irrelevant, while a tanh (hyperbolic tangent) function produces the candidate values to be added to the cell state.

Here is the operation that extracts the candidate information to add:

$$g_t = \tanh(U_g h_{t-1} + W_g x_t)$$

This operation computes the sigmoid mask that filters the relevant candidate values from the irrelevant ones:

$$i_t = \sigma(U_i h_{t-1} + W_i x_t)$$

Now this operation adds the filtered new information to the filtered previous cell state to get the new cell state ($\odot$ denotes element-wise multiplication):

$$C_t = f_t \odot C_{t-1} + i_t \odot g_t$$

The third operation is a little trickier. If you take a closer look at the cell state, it contains both the earlier context that was not forgotten and the new details that were just added. We need to extract from it the information that is relevant right now and pass it on as the next hidden state. For this, an output gate computes a sigmoid mask from $h_{t-1}$ and $x_t$, and the new cell state is passed through a tanh activation and multiplied by that mask to produce the next hidden state. The calculation looks like this:

$$o_t = \sigma(U_o h_{t-1} + W_o x_t)$$

$$h_t = o_t \odot \tanh(C_t)$$

Together, these steps make up the complete forward computation of an LSTM. LSTMs are trained with a variant of backpropagation called Backpropagation Through Time (BPTT). We will not discuss BPTT here, since it is beyond the goal of this article and somewhat math-heavy. If you want to know how BPTT works, I’ve written an article about it, which you can find here: BPTT Derivation
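To make the gate computations concrete, here is a minimal NumPy sketch of a single LSTM time step, run over a short random sequence. It follows the forget/add/output structure described above, with U weights on the hidden state and W weights on the input. The weights are random, the sizes are arbitrary, and biases are omitted for readability; a practical layer such as Keras’s LSTM includes a bias in each gate and learns all the weights with BPTT.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step (biases omitted). Returns the new hidden and cell states."""
    U, W = params["U"], params["W"]  # hidden-state and input weight matrices per gate
    f_t = sigmoid(U["f"] @ h_prev + W["f"] @ x_t)  # forget gate: what to keep from c_prev
    g_t = np.tanh(U["g"] @ h_prev + W["g"] @ x_t)  # candidate new information
    i_t = sigmoid(U["i"] @ h_prev + W["i"] @ x_t)  # add (input) gate: which candidates to use
    c_t = f_t * c_prev + i_t * g_t                 # new cell state
    o_t = sigmoid(U["o"] @ h_prev + W["o"] @ x_t)  # output gate
    h_t = o_t * np.tanh(c_t)                       # new hidden state
    return h_t, c_t

rng = np.random.default_rng(0)
hidden, inp = 4, 3
params = {
    "U": {g: rng.normal(size=(hidden, hidden)) for g in "fgio"},
    "W": {g: rng.normal(size=(hidden, inp)) for g in "fgio"},
}
h, c = np.zeros(hidden), np.zeros(hidden)
for x in rng.normal(size=(5, inp)):  # run over a sequence of 5 input vectors
    h, c = lstm_step(x, h, c, params)
print(h.shape, c.shape)  # (4,) (4,)
```

The same weights are reused at every timestep; only h and c change as the sequence is consumed.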

3. LSTM for text generation

Now let’s discuss how to perform text generation with an LSTM. As mentioned earlier, generating text means predicting tokens iteratively, one after the other. For each prediction, we feed the previously generated tokens back in to get the current one. Here is how it can be done:

Text Generation with LSTM, Source: Speech and Language Processing. An Introduction to Natural Language Processing, Computational Linguistics, and Speech Recognition, Third Edition

Here the first token passed in is the start token <s>. After receiving it, the network predicts “So”; that word is passed back in as the input at the next timestep, the network predicts “long”, and so on. At each timestep, the LSTM outputs a probability distribution over the vocabulary indicating which token is most likely to come next. Now let’s look at each part of the process.

3.1 Inputs and Embeddings

The input to an LSTM is an integer token representing a word or subword from the vocabulary. For example, consider a vocabulary of 4 words, [we, are, learning, machine], numbered 1 to 4. The sequence “We are Learning Machine Learning” is then represented as [1, 2, 3, 4, 3]. In real-world scenarios, the vocabulary is far larger and can reach millions of words.
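A minimal Python sketch of this encoding, using the four-word vocabulary from the example. The encode helper is hypothetical; it lowercases the sentence first so that “We” and “we” map to the same id, and an out-of-vocabulary word would raise a KeyError.

```python
def encode(sentence, vocab):
    """Map each word (lowercased) to its integer id in the vocabulary."""
    return [vocab[word] for word in sentence.lower().split()]

# Ids start at 1, matching the example in the text
vocab = {word: idx for idx, word in enumerate(["we", "are", "learning", "machine"], start=1)}
print(encode("We are Learning Machine Learning", vocab))  # [1, 2, 3, 4, 3]
```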

The first thing we have to do is specify the sequence length the LSTM will process. Sequence length is the number of tokens the network takes as context to produce the next word.

Once a sequence length is chosen, every sequence passed into the LSTM must have that same size, because the training inputs are batched into fixed-shape arrays. In practice, this means the model learns to predict from a context of at most that many tokens. A longer sequence length gives the model more context, which can improve predictions, but it also requires more computation.

After specifying the sequence length, each word in the sequence window needs to be converted to a unique integer token.
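One common way to force every input to the same length is to truncate long inputs and pad short ones. The sketch below is a hypothetical helper, not part of the article’s pipeline (the code later in the article chunks the text instead). It keeps the most recent tokens, since those are the context for the next word, and left-pads with id 0, assuming 0 is reserved for padding in a vocabulary whose ids start at 1. Keras’s pad_sequences utility behaves similarly with its default 'pre' padding and truncation.

```python
def to_fixed_length(token_ids, sequence_length, pad_id=0):
    """Keep the last `sequence_length` ids; left-pad shorter inputs with `pad_id`."""
    window = token_ids[max(0, len(token_ids) - sequence_length):]
    return [pad_id] * (sequence_length - len(window)) + window

print(to_fixed_length([1, 2, 3], 5))           # [0, 0, 1, 2, 3]
print(to_fixed_length([1, 2, 3, 4, 5, 6], 5))  # [2, 3, 4, 5, 6]
```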

The next step is embedding. Embeddings are dense vector representations of words or subwords. They encode the semantic meaning of a word in context, which helps NLP models find similarities between words and understand context. There are two ways to obtain embeddings for text generation: use pre-trained word embeddings such as Google’s Word2Vec, or learn our own. Both have advantages and disadvantages. Learning our own embeddings adds computational overhead, since they are trained along with the rest of the model. If we are aiming for accurate and precise text generation, especially with limited training data, pre-trained embeddings are often preferable.

Each token in the sequence is converted to an embedding of fixed length, and that length (the embedding dimension) is a design choice. Larger embeddings can capture more detailed information about the input data, but they cost more computation. Moreover, a larger embedding dimension adds parameters and can hurt performance if the model cannot learn such a rich representation effectively from the available data.
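Under the hood, an embedding layer is a lookup table: a matrix with one row per token id, where looking up a token means selecting its row. A minimal NumPy sketch, assuming the 4-word vocabulary from the earlier example and an arbitrary embedding dimension of 8; the values are random here, whereas a trainable layer such as Keras’s Embedding updates them during training.

```python
import numpy as np

vocab_size, embedding_dim = 4, 8
rng = np.random.default_rng(42)
# Row i holds the vector for token id i; row 0 is left for padding.
embedding_matrix = rng.normal(size=(vocab_size + 1, embedding_dim))

token_ids = np.array([1, 2, 3, 4, 3])   # "we are learning machine learning"
embedded = embedding_matrix[token_ids]  # lookup = row indexing
print(embedded.shape)                   # (5, 8)
print(np.array_equal(embedded[2], embedded[4]))  # True: same word, same vector
```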

3.2 Vocabulary

The vocabulary consists of all the unique tokens in the training data. We can think of it as a dictionary that maps each unique word to a unique value. Character-level generation has a relatively small vocabulary, since it only contains the characters used in the language. A word-level vocabulary, however, can become very large as the training data grows. Heaps’ law describes this: the number of unique words $V$ in a corpus of $N$ tokens grows roughly as $V = kN^{\beta}$ with $0 < \beta < 1$, so the vocabulary grows quickly at first and then more and more slowly, but it keeps growing. A larger training set therefore usually has a larger vocabulary.
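The growth described by Heaps’ law can be measured by counting distinct tokens as a corpus is read. The vocab_growth helper below is a hypothetical sketch; on a toy sentence it only shows the mechanics, while on a real corpus the resulting curve flattens the way Heaps’ law predicts. It also contrasts word-level and character-level vocabulary sizes on the same text.

```python
def vocab_growth(tokens):
    """Return the vocabulary size after each token is read (a type-token curve)."""
    seen, sizes = set(), []
    for token in tokens:
        seen.add(token)
        sizes.append(len(seen))
    return sizes

text = "the cat sat on the mat and the cat ran"
print(vocab_growth(text.split())[-1])                 # word-level vocabulary: 7
print(vocab_growth(list(text.replace(" ", "")))[-1])  # character-level vocabulary: 11
```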

3.3 Softmax for Prediction

Alright, we’ve covered the internal workings of the LSTM, such as the hidden states, but we have yet to look at the output layer of the network. It is easier to start from what we need the output layer to do: predict the next token. The best way to do this is to assign a probability to every unique token in the vocabulary, so that we can select the token with the highest probability as the next token in the sequence. This is where softmax comes in. Softmax turns the output layer’s raw scores $z_1, \dots, z_V$ (one per vocabulary token) into a probability distribution that sums to one:

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{V} e^{z_j}}$$

The main attraction of softmax is that it squashes the values into the range 0 to 1 and makes them sum to one. Sigmoid also maps each value into (0, 1), but it treats each output independently, so the outputs do not necessarily sum to one. This matters for token prediction because the candidate tokens are not independent of each other: only one of them comes next, so raising the probability of one token must lower the others. By ensuring the probabilities sum to one, softmax provides a proper probability distribution over the possible tokens, reflecting the relative likelihood of each token given the context.
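A small NumPy comparison of softmax and sigmoid on the same raw scores (the three logits are made up). The softmax subtracts the maximum score before exponentiating, a standard trick that avoids overflow without changing the result, since the shift cancels between numerator and denominator.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax: subtracting the max does not change the result."""
    e = np.exp(z - np.max(z))
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

logits = np.array([2.0, 1.0, 0.1])  # raw output scores for a 3-token vocabulary
p = softmax(logits)
print(p.sum())                 # 1.0 (up to floating-point rounding)
print(sigmoid(logits).sum())   # about 2.14, not a probability distribution
print(int(np.argmax(p)))       # 0, the most probable token id
```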

The output layer therefore has as many units as there are tokens in the vocabulary, and the model learns to generate text by assigning a probability to each of them.

4. Code Implementation

Let’s get to the interesting part and see the ideas discussed above in practice. The first thing we need is training data. Here is the link to the dataset I used: Drive Link

This is a large text file which is more than enough for our task.

4.1 Data Preprocessing and Preparation

Let’s do some data preprocessing. For text, preprocessing involves removing special characters, numbers, unwanted symbols, and so on.

First, let’s import the necessary libraries:

```python
import numpy as np
import pandas as pd
import pickle
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.models import Sequential
import tensorflow as tf
```

Load the data:

```python
with open('/content/drive/MyDrive/Natural-Language-Processing/train.en', 'r') as f:
    sentences = f.readlines()

sentences = sentences[:3000]
```

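Here we use only the first 3000 sentences because we are running this on the free tier of Google Colab, which has limited compute and memory. Next, we lowercase each sentence and remove digits and any characters other than letters and whitespace:

```python
import re

def preprocess(sentence):
    sentence = sentence.lower()
    sentence = re.sub(r'\d+', '', sentence)          # drop numbers
    sentence = re.sub(r'[^a-zA-Z\s]', '', sentence)  # drop punctuation and other symbols
    sentence = sentence.strip()                      # trim surrounding whitespace and newlines
    return sentence

sentences = [preprocess(sentence) for sentence in sentences]
```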

Sequence length is an important consideration when training the LSTM. All sequences passed to the LSTM must have the same length, but the sentences in our corpus do not. One way to handle this is to find the longest sentence in the corpus and use its length as the sequence length. However, this approach has two main issues. First, a longer sequence length makes the training data more complex and harder for a small model to learn. Second, it requires excessive padding. The longest sentence is usually an outlier, and most other sentences are much shorter, so they need many padding tokens to reach that length. This can hurt the model’s learning and generalization, because a large share of what the model sees is padding rather than meaningful tokens.

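To address this issue, we can use a technique called bucketing: grouping sequences of similar lengths into buckets. The code below takes a simpler route in the same spirit. It concatenates the words of all sentences into one stream and cuts that stream into chunks of exactly sequence_length words, so every chunk except possibly the last one has the same length:

```python
def adjust_sentence_length(sentences, sequence_length):
    adjusted_sentences = []
    buffer = []

    for sentence in sentences:
        words = sentence.split()
        buffer.extend(words)  # add this sentence's words to the running stream

        # Cut off full chunks of exactly sequence_length words
        while len(buffer) >= sequence_length:
            adjusted_sentences.append(' '.join(buffer[:sequence_length]))
            buffer = buffer[sequence_length:]

    # Keep whatever is left over as a final, shorter chunk
    if buffer:
        adjusted_sentences.append(' '.join(buffer))

    # Post-processing to ensure all sentences are around sequence_length:
    # merge a short chunk into the previous one if the result still fits
    adjusted_sentences_final = []
    for sentence in adjusted_sentences:
        words = sentence.split()
        if len(words) < sequence_length and adjusted_sentences_final:
            last_sentence_words = adjusted_sentences_final[-1].split()
            if len(last_sentence_words) + len(words) <= sequence_length:
                adjusted_sentences_final[-1] += ' ' + sentence
                continue
        adjusted_sentences_final.append(sentence)

    return adjusted_sentences_final

sentences = adjust_sentence_length(sentences, 20)
```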

Here we have chosen 20 as the sequence length. You can try other values, but make sure the sequence length is not too long.

Now let’s build the vocabulary: a dictionary whose keys are the words and whose values are unique integer tokens. Indices start at 1, which leaves 0 unused:

```python
def generate_vocab_map(d, sentences):
    index = 1  # start at 1 so that id 0 stays unused
    for sentence in sentences:
        for word in sentence.split():
            if word not in d:
                d[word] = index
                index += 1

vocab_mapping = {}
generate_vocab_map(vocab_mapping, sentences)
```

When the model predicts tokens, we need to decode them from integers back to words to produce text. Let’s define a function that does this:

```python
def convert_int_to_text(int_sequences, vocab_map):
    # Create a reverse mapping from index to word
    reverse_vocab_map = {index: word for word, index in vocab_map.items()}

    text_sequences = []
    for int_sequence in int_sequences:
        text_sequence = []
        for token in int_sequence:
            if token in reverse_vocab_map:  # skip ids with no word, such as 0
                text_sequence.append(reverse_vocab_map[token])
        text_sequences.append(' '.join(text_sequence))

    return text_sequences
```

Let’s convert each sentence into a sequence of integer tokens:

```python
sequences = []

def generate_tokens(tokens_list, sentences, vocab_map):
    for sentence in sentences:
        sentence_tokens = []
        for word in sentence.split():
            if word in vocab_map:
                sentence_tokens.append(vocab_map[word])
        tokens_list.append(sentence_tokens)

generate_tokens(sequences, sentences, vocab_mapping)
```

This produces one list of integer tokens per chunk, each token representing one word of that chunk. Because of the chunking step above, each list has at most 20 tokens.

The sequences can still differ in length. In that case we would normally add padding tokens to reach the sequence length, but instead I flatten all the token lists into a single stream and slide a window of 20 tokens over it, one position at a time. Each window becomes an input sequence, and the token right after the window becomes its target. This avoids padding tokens altogether:

```python
tokens_list = [sequence for sequence in sequences]
# Flatten the per-chunk token lists into one continuous stream
tokens = [token for word_seq in tokens_list for token in word_seq]

input_sequence = []
targets = []
sequence_length = 20

for i in range(len(tokens) - sequence_length):
    input_sequence.append(tokens[i:i + sequence_length])  # 20 consecutive tokens
    targets.append(tokens[i + sequence_length])           # the token that follows them
```
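To see what this windowing produces, here is the same logic on a tiny made-up token stream with a window of 3 instead of 20; make_windows is a hypothetical helper name wrapping the loop.

```python
def make_windows(tokens, sequence_length):
    """Slide a window of `sequence_length` tokens over the stream with stride 1.
    Each window is an input; the token right after it is the target."""
    inputs, targets = [], []
    for i in range(len(tokens) - sequence_length):
        inputs.append(tokens[i:i + sequence_length])
        targets.append(tokens[i + sequence_length])
    return inputs, targets

inputs, targets = make_windows([5, 8, 2, 9, 4, 7], 3)
print(inputs)   # [[5, 8, 2], [8, 2, 9], [2, 9, 4]]
print(targets)  # [9, 4, 7]
```

A stream of N tokens yields N - sequence_length training pairs, and consecutive windows overlap in all but one token, so the dataset is roughly as large as the token stream itself.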

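Now we have the input sequences and targets. The last step is to convert the lists to NumPy arrays and convert the targets to one-hot encodings, which the categorical cross-entropy loss used below expects. We can use the Keras to_categorical function for the one-hot conversion. Passing num_classes explicitly keeps the number of columns equal to the model’s output size even if the highest token id never appears as a target:

```python
from tensorflow.keras.utils import to_categorical

X = np.array(input_sequence)
y = np.array(targets)

# Converting to one-hot encodings, one column per vocabulary index (0 included)
y = to_categorical(y, num_classes=len(vocab_mapping) + 1)
```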

4.2 Model Building & Training

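To create the model architecture, we can use the TensorFlow Sequential API. I used an input embedding, two LSTM layers, and a batch normalization layer, which normalizes the activations so that their distribution stays stable during training and later layers do not have to keep adapting to a shifting distribution. The first LSTM returns its full output sequence so that the second LSTM can process it; the second returns only its final hidden state, which the dense layers map to a probability distribution over the vocabulary. Here is the network architecture:

```python
model = Sequential()
model.add(Embedding(input_dim=len(vocab_mapping) + 1, output_dim=128))
model.add(LSTM(128, return_sequences=True))  # pass the whole sequence to the next LSTM
model.add(LSTM(128))                         # return only the last hidden state
model.add(BatchNormalization())
model.add(Dense(128, activation='relu'))
model.add(Dense(units=len(vocab_mapping) + 1, activation='softmax'))  # one unit per token id

model.compile(loss="categorical_crossentropy", optimizer='adam', metrics=['accuracy'])
```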

The Embedding layer takes the vocabulary size + 1 as its input dimension (token ids start at 1, so index 0 also needs a row) and maps each token to a dense 128-dimensional vector. These vectors are learned during training and take far less memory than one-hot encoding the inputs, which would require a vector the size of the vocabulary for every token in every training sample.
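The same memory argument applies to the one-hot targets produced by to_categorical, which store one value per vocabulary entry for every training sample. Keras also accepts integer targets directly with loss='sparse_categorical_crossentropy', which computes the same loss without materializing the one-hot matrix. The NumPy sketch below, with made-up probabilities for a 3-token vocabulary, checks that the two formulations give the same per-sample loss; it illustrates the math and is not how Keras computes it internally.

```python
import numpy as np

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])  # model outputs for 2 samples, vocabulary of 3
labels = np.array([0, 2])            # integer targets

one_hot = np.eye(probs.shape[1])[labels]  # what a one-hot conversion would produce
categorical = -np.sum(one_hot * np.log(probs), axis=1)
sparse = -np.log(probs[np.arange(len(labels)), labels])

print(np.allclose(categorical, sparse))  # True: same loss per sample
print(one_hot.size, labels.size)         # 6 vs 2 stored values
```

Switching would mean skipping the to_categorical step and compiling with the sparse loss; this is offered as an alternative, not as the setup the article’s results were obtained with.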

This is a relatively simple model. If you have the computational resources, you can experiment with other architectures and more complex models to improve text generation.

Let’s train the model:

```python
model.fit(X, y, batch_size=32, epochs=100)
```

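After 100 epochs of training, the model reached 78% accuracy on the training set, which is not perfect but pretty good. Keep in mind that this is training accuracy; since no validation set was held out, it says little about how well the model generalizes.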

Here is the function to predict the next word:

```python
def predict_next_word(input_text):
    input_text = input_text.lower()
    word_tokens = input_text.split()
    # Skip words that are not in the vocabulary instead of raising a KeyError
    int_tokens = [vocab_mapping[token] for token in word_tokens if token in vocab_mapping]

    prediction = model.predict(np.array([int_tokens]), verbose=0)  # shape (1, vocab size + 1)
    prediction_idx = int(np.argmax(prediction))
    return convert_int_to_text([[prediction_idx]], vocab_mapping)[0]
```

First, the input text is lowercased and converted to integer tokens. The tokens are then passed to the model, which returns a probability distribution over the vocabulary, and NumPy’s argmax function picks the token with the highest probability. Finally, the predicted integer token is converted back to its word using the function we defined earlier.

This function only predicts the next word for a given sequence, but our goal is to generate text. For that, we can define another function that performs autoregressive generation: it appends the token generated by the model to the input sequence and passes the result back to the model to get the next token.

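```python
def generate_text(input_text, n_words):
    word_sequence = input_text.split()
    context = word_sequence[:]
    for _ in range(n_words):
        prediction = predict_next_word(' '.join(context))
        word_sequence.append(prediction)
        context.append(prediction)
        # Keep at most the last 20 words, matching the training sequence length
        if len(context) > 20:
            context = context[-20:]

    return ' '.join(word_sequence)

generate_text("everyone is living happy", 20)
```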


Generated Text:

with you and how it has come even my mother said that may be away your sight and hearing allah has come over

Not bad at all. Even though the predicted tokens do not form a cohesive story or accurate descriptions, the model picks up some of the context and predicts plausible next tokens. Training the model longer on a larger training set could further improve its performance and produce more coherent text. Here are some potential improvements:

  • Add more training data
  • Increase the sequence length
  • Use larger embedding dimensions
  • Use pre-trained word embeddings
  • Try different LSTM architectures
  • Train for more epochs

Finally, we can save the trained model using model.save().
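The vocabulary mapping needs to be saved too, since the model only understands integer ids and generation depends on mapping words to and from them. The article already imports pickle, so one way to do this, sketched here under that assumption rather than taken from the original code, is to pickle the dictionary next to the model file (for the model itself, something like model.save("text_gen_lstm.keras"); the filename is a placeholder, and recent Keras versions expect a .keras extension). The helper names below are hypothetical.

```python
import os
import pickle
import tempfile

def save_vocab(vocab_map, path):
    """Persist the word-to-id mapping so the saved model can be used later."""
    with open(path, "wb") as f:
        pickle.dump(vocab_map, f)

def load_vocab(path):
    with open(path, "rb") as f:
        return pickle.load(f)

vocab_mapping = {"we": 1, "are": 2, "learning": 3, "machine": 4}  # toy example
path = os.path.join(tempfile.mkdtemp(), "vocab_mapping.pkl")
save_vocab(vocab_mapping, path)
print(load_vocab(path) == vocab_mapping)  # True
```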

That’s it: we have created a text generation model from scratch using an LSTM. I hope this article helps you understand how text generation works and how an LSTM can be used for it. The objective was not to build a model like GPT or other state-of-the-art text generation models, but to focus on the general workings of text generation and on how a model like an LSTM can learn from text and predict tokens based on what it has learned.

Thanks for reading!