Topic modelling for short texts
Latent Dirichlet Allocation (LDA) is a very popular model for finding topics in text documents. However, LDA performs poorly at discovering topics in short documents. This is a problem when you try to model topics in tweets or forum posts, since this kind of content is typically short-form.
LDA struggles with short texts because it assumes that every document is a mixture of multiple topics, which is usually not the case for shorter documents. Instead, we can assume that each document comes from a single underlying cluster that determines its word probabilities. This model is called a Dirichlet-Multinomial Mixture (DMM), and it can be used both for clustering texts and for uncovering topics.
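To make that assumption concrete, here is a minimal sketch of the DMM generative story using NumPy. The cluster weights, word distributions, and document length are made-up toy values, not parameters of any fitted model; the point is simply that each document picks one cluster and draws all of its words from that cluster's distribution.
import numpy as np

rng = np.random.default_rng(42)

# Toy parameters: 2 clusters over a 5-word vocabulary (made-up values).
cluster_weights = np.array([0.6, 0.4])   # mixture weights over clusters
word_probs = np.array([
    [0.5, 0.2, 0.1, 0.1, 0.1],           # word distribution of cluster 0
    [0.1, 0.1, 0.1, 0.2, 0.5],           # word distribution of cluster 1
])

# Generate one short document: pick ONE cluster, then draw every word from it.
cluster = rng.choice(2, p=cluster_weights)
doc_word_counts = rng.multinomial(8, word_probs[cluster])  # an 8-word document
print(cluster, doc_word_counts)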
In this tutorial we will look at how you can use DirichletMultinomialMixture to find topics in short texts.
Data Loading
We are going to use a subset of the 20 Newsgroups dataset from scikit-learn. For now we only load the alt.atheism newsgroup, so that running the algorithm doesn't take long.
from sklearn.datasets import fetch_20newsgroups

corpus = fetch_20newsgroups(
    subset="all", remove=("headers", "footers", "quotes"), categories=["alt.atheism"]
).data
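As a quick sanity check, you can look at how many documents were loaded and peek at one of them (the exact count depends on your scikit-learn version and data download):
print(len(corpus))      # number of alt.atheism documents
print(corpus[0][:200])  # first 200 characters of the first document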
Preprocessing
We will use scikit-learn's CountVectorizer to extract a bag-of-words matrix from our texts. We filter out words that are too frequent or too infrequent, as well as English stop words.
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(
    min_df=10, max_df=0.1, max_features=4000, stop_words="english"
)
X = vectorizer.fit_transform(corpus)
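It is worth confirming the shape of the resulting document-term matrix, i.e. how many documents and how many vocabulary terms survived the filtering:
print(X.shape)                                  # (n_documents, n_terms)
print(len(vectorizer.get_feature_names_out()))  # vocabulary size, at most 4000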
Model fitting
We can now fit a Dirichlet-Multinomial Mixture to our data. I chose 5 topics, since that makes the results easy to display.
from noloox.mixture import DirichletMultinomialMixture
model = DirichletMultinomialMixture(5).fit(X)
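Before plotting, we can do a quick textual check of the fitted parameters. The snippet below relies only on the model's components_ attribute, which, as the interpretation section below also assumes, has one row per topic and one column per vocabulary term.
import numpy as np

# One row per topic, one column per vocabulary term.
print(model.components_.shape)

# Quick preview of the five most important words in each topic.
vocab = vectorizer.get_feature_names_out()
for i, comp in enumerate(model.components_):
    top = np.argsort(-comp)[:5]
    print(f"Topic {i}:", ", ".join(vocab[top]))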
Model Interpretation
To understand what the topics mean, we plot the top 10 words of each topic on horizontal bar charts using Plotly.
from plotly.subplots import make_subplots
import numpy as np

fig = make_subplots(rows=1, cols=5, subplot_titles=[f"Topic {i}" for i in range(5)])
vocab = vectorizer.get_feature_names_out()
for i, comp in enumerate(model.components_):
    # Indices of the 10 highest-ranked words for this topic.
    top = np.argsort(-comp)[:10]
    # Reverse the order so the most important word ends up at the top of the bar chart.
    fig.add_bar(
        y=vocab[top][::-1],
        x=comp[top][::-1],
        row=1,
        col=i + 1,
        orientation="h",
        showlegend=False,
    )
fig.show()