
Late Interaction Topic Models

Late interaction, or multi-vector, models use the token representations a Sentence Transformer produces before they are pooled into a single document embedding. This can be particularly useful for clustering models: by default, they assign one topic to each document, but with access to token representations they can assign topics on a per-token basis.

Info

There are currently no native late-interaction models in Turftopic, i.e. models that explicitly model token representations in the context of a document. We are working on implementing such models, but for the time being we include wrappers that can make regular models use embeddings of higher granularity. Visualization utilities are also on the way.

Encoding Tokens and Ragged Array Manipulation

Turftopic provides a convenience class for encoding documents at the token level with Sentence Transformers, instead of pooling token embeddings into document embeddings. To initialize an encoder, load LateSentenceTransformer and specify which model you would like to use:

Tip

While you could use any encoder model with LateSentenceTransformer, we recommend sticking to ones that use mean pooling and normalize their embeddings. With such models, you can be sure that the pooled document embeddings and the token embeddings are in the same semantic space.

Token Embeddings

from turftopic.late import LateSentenceTransformer

documents = ["This is a text", "This is another but slightly longer text"]

encoder = LateSentenceTransformer("all-MiniLM-L6-v2")
token_embeddings, offsets = encoder.encode_tokens(documents)
print(token_embeddings)
print(offsets)
[
  array([[-0.01135089,  0.04170538,  0.00379963, ...,  0.01383126,
        -0.00274855, -0.05360783],
        ...
       [ 0.05069249,  0.03840942, -0.03545087, ...,  0.03142243,
         0.01929936, -0.09216172]],
        shape=(6, 384), dtype=float32),
  array([[-0.00047079,  0.03402771,  0.00037086, ...,  0.0228903 ,
        -0.01734272, -0.04073172],
       ...,
       [-0.02586325,  0.03737643,  0.02260585, ...,  0.05613737,
        -0.01032581, -0.03799873]], shape=(9, 384), dtype=float32)
]
[[(0, 0), (0, 4), (5, 7), (8, 9), (10, 14), (0, 0)], [(0, 0), (0, 4), (5, 7), (8, 15), (16, 19), (20, 28), (29, 35), (36, 40), (0, 0)]]

As you can see, encode_tokens returns two values. The first is the token embeddings: a ragged array, where longer documents have more embeddings. offsets contains a list of tuples for each document, where the first element of each tuple is the start character of the given token and the second element is the end character.
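The character offsets make it easy to recover each token's surface form by slicing the document. A minimal sketch using the offsets printed above (the (0, 0) entries presumably come from special tokens like [CLS]/[SEP], which have no span in the original text):

```python
documents = ["This is a text", "This is another but slightly longer text"]
# Offsets for the first document, as printed above; (0, 0) entries come from
# special tokens that map to no characters in the original text
offsets = [(0, 0), (0, 4), (5, 7), (8, 9), (10, 14), (0, 0)]
tokens = [documents[0][start:end] for start, end in offsets]
print(tokens)  # ['', 'This', 'is', 'a', 'text', '']
```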

Rolling Window Embeddings

You can also pool these embeddings over a rolling window of tokens. This way, you still represent each document with multiple vectors, but don't need to model every token individually:

window_embeddings, window_offsets = encoder.encode_windows(documents, window_size=5, step_size=4)
for doc_emb, doc_off in zip(window_embeddings, window_offsets):
    print(doc_emb.shape, doc_off)
(2, 384) [(0, 14), (10, 0)]
(3, 384) [(0, 19), (16, 0), (0, 0)]
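The windowing logic is straightforward: token embeddings are averaged over each window, and a window's offset spans from its first token's start to its last token's end. A sketch of the offset computation, reproducing the first document's window offsets above:

```python
# Token offsets of the first document (as returned by encode_tokens above)
token_offsets = [(0, 0), (0, 4), (5, 7), (8, 9), (10, 14), (0, 0)]
window_size, step_size = 5, 4

window_offsets = []
for start in range(0, len(token_offsets), step_size):
    window = token_offsets[start:start + window_size]
    # A window spans from its first token's start to its last token's end
    window_offsets.append((window[0][0], window[-1][1]))
print(window_offsets)  # [(0, 14), (10, 0)]
```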

Ragged Array Manipulation

These ragged data structures are hard to work with, especially when using array operations, so we include convenience functions for manipulating them. flatten_repr flattens the ragged array into a single large array and returns the length of each sub-array:

from turftopic.late import flatten_repr, unflatten_repr

flat_token_embeddings, lengths = flatten_repr(token_embeddings)
print(flat_token_embeddings.shape) 
# (15, 384)

unflatten_repr will turn a flattened representation array into a ragged array:

token_embeddings = unflatten_repr(flat_token_embeddings, lengths)

pool_flat will pool document representations in a flattened array using a given aggregation function:

import numpy as np
from turftopic.late import pool_flat

pooled = pool_flat(flat_token_embeddings, lengths, agg=np.nanmean)
print(pooled.shape)
# (2, 384)
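If you want to see what these helpers do under the hood, the same round trip can be sketched with plain NumPy (toy shapes, not real embeddings):

```python
import numpy as np

# Two toy "documents" with 6 and 9 token vectors of dimension 4
ragged = [np.ones((6, 4)), np.full((9, 4), 2.0)]

# flatten_repr: concatenate and remember each document's length
flat = np.concatenate(ragged, axis=0)
lengths = [arr.shape[0] for arr in ragged]
print(flat.shape)  # (15, 4)

# unflatten_repr: split the flat array back at the recorded boundaries
unflattened = np.split(flat, np.cumsum(lengths)[:-1])
print([arr.shape for arr in unflattened])  # [(6, 4), (9, 4)]

# pool_flat: aggregate each document's rows into a single vector
pooled = np.stack([arr.mean(axis=0) for arr in unflattened])
print(pooled.shape)  # (2, 4)
```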

Turning Regular Models into Multi-Vector Models

The LateWrapper class can turn your regular topic models into ones that utilize windowed or token-level embeddings. Here's how LateWrapper works:

  1. It encodes documents at a token or window-level based on its parameters.
  2. It flattens the embedding array and feeds this into the topic model, along with the token/window text.
  3. It unflattens the output of the topic model (doc_topic_matrix) into a ragged array, where you get topic importance for each token.
  4. [OPTIONAL] It pools token-level topic content on the document level, so that you get one document-topic vector for each document instead of each token.
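The steps above can be sketched with plain NumPy, using a dummy stand-in for the topic model (random embeddings and a uniform topic distribution, purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# Step 1 (stand-in): two documents encoded into 3 and 2 window embeddings
window_embeddings = [rng.random((3, 8)), rng.random((2, 8))]

# Step 2: flatten so the topic model sees one row per window
flat = np.concatenate(window_embeddings, axis=0)
lengths = [emb.shape[0] for emb in window_embeddings]

# Stand-in for the topic model's output: a uniform distribution over 4 topics
window_topics = np.full((flat.shape[0], 4), 0.25)

# Step 3: unflatten into a ragged array of per-window topic proportions
doc_topic = np.split(window_topics, np.cumsum(lengths)[:-1])
print([arr.shape for arr in doc_topic])  # [(3, 4), (2, 4)]

# Step 4 (optional): pool to one topic vector per document
pooled = np.stack([arr.mean(axis=0) for arr in doc_topic])
print(pooled.shape)  # (2, 4)
```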

Let's see how this works in practice, and create a Topeax model that uses windowed embeddings instead of document-level embeddings:

from sklearn.datasets import fetch_20newsgroups
from turftopic import Topeax
from turftopic.late import LateWrapper, LateSentenceTransformer

corpus = fetch_20newsgroups(subset="all", categories=["alt.atheism"]).data

model = LateWrapper(
    Topeax(encoder=LateSentenceTransformer("all-MiniLM-L6-v2")),
    window_size=50, # If we don't specify window size, it will use token-level embeddings
    step_size=40, # Since the step size is smaller than the window, we will get overlapping windows
)
doc_topic_matrix, offsets = model.fit_transform(corpus)
model.print_topics()
Topic ID Highest Ranking
0 morality, moral, morals, immoral, objective, behavior, instinctive, species, inherent, animals
1 matthew, luke, bible, text, passages, mormon, texts, translations, copy, john
2 atheism, agnostics, atheist, beliefs, belief, faith, contradictory, believers, contradictions, theists
3 punishment, cruel, abortion, penalty, death, constitution, homosexuality, painless, capital, punish
4 war, arms, invaded, gulf, hussein, civilians, military, kuwait, peace, sell
5 islam, islamic, muslim, qur, muslims, imams, rushdie, quran, koran, khomeini

The document-topic matrix we created is now a ragged array and contains document-topic proportions for each window in a document. Let's see what this means in practice for the first document in our corpus:

import pandas as pd

# We select document 0, then collect all information into a dataframe:
window_topic_matrix = doc_topic_matrix[0]
window_offs = offsets[0]
document = corpus[0]
# We extract the text for each window based on the offsets
window_text = [document[window_start: window_end] for window_start, window_end in window_offs]
df = pd.DataFrame(window_topic_matrix, index=window_text, columns=model.topic_names)
print(df)

                                                    0_morality_moral_morals_immoral  1_matthew_luke_bible_text  ...  4_war_arms_invaded_gulf  5_islam_islamic_muslim_qur
From: acooper@mac.cc.macalstr.edu (Turin Turamb...                         0.334267               1.287207e-13  ...             2.626869e-26                1.459101e-04
alester College\nLines: 55\n\nIn article <C5sA2...                         0.360400               8.898302e-14  ...             3.290858e-26                1.382718e-04
u (Mike Cobb) writes:\n> I guess I'm delving in...                         0.847002               5.002921e-22  ...             4.852574e-41                3.141366e-07
this you just have a spiral.  What\nwould then ...                         0.848413               5.819050e-22  ...             8.139559e-41                3.286224e-07
, even though this would hardly seem moral.  Fo...                         0.863685               1.272204e-21  ...             2.823941e-41                2.815930e-07
whatever helps this goal is\n"moral", whatever ...                         0.864913               1.584558e-21  ...             5.780971e-41                3.003952e-07
a "hyper-morality" to apply to just the methods...                         0.865558               1.919885e-21  ...             1.251694e-40                3.231265e-07
not doing something because it is\n> a personal...                         0.868360               2.951441e-21  ...             3.085662e-40                3.494368e-07
we only consider something moral or immoral if ...                         0.872827               5.444738e-21  ...             4.708349e-40                3.580695e-07
here we have a way to discriminate\nmorals.  I ...                         0.876951               1.021014e-20  ...             3.486096e-40                3.411401e-07
enough and\nlistened to the arguments, I could ...                         0.878680               2.302363e-20  ...             5.866410e-40                3.565728e-07
.  Or, as you brought out,\n> if whatever is ri...                         0.878953               3.004052e-20  ...             5.977738e-40                3.566668e-07
> *******************************                                          0.647793               5.664651e-17  ...             1.805073e-19                4.612731e-04
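With the ragged matrix at hand you can, for instance, find each window's dominant topic with a simple argmax (hypothetical proportions, purely for illustration):

```python
import numpy as np

# Hypothetical window-topic proportions for one document (4 windows, 3 topics)
window_topic_matrix = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
])
# Index of the highest-ranking topic in each window
dominant_topics = window_topic_matrix.argmax(axis=1)
print(dominant_topics)  # [0 1 0 2]
```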

C-Top2Vec

Contextual Top2Vec (Angelov and Inkpen, 2024) is a late-interaction topic model that uses windowed representations. The model is essentially the same as wrapping a regular Top2Vec model in LateWrapper, but we provide a convenience class in Turftopic so that it's easy to initialize. It comes pre-loaded with the following features:

  • Same hyperparameters as in Angelov and Inkpen (2024)
  • Phrase-vectorizer that finds regular phrases based on PMI
  • LateSentenceTransformer as the default encoder; you can specify any model.

Our implementation is much more flexible than the original top2vec package, and you may be able to use much more powerful or novel embedding models.

from turftopic import CTop2Vec

model = CTop2Vec(n_reduce_to=5)
doc_topic_matrix = model.fit_transform(corpus)

model.print_topics()
Topic ID Highest Ranking
-1 caused atheism organization, genocide caused atheism, atheism organization, atheism, subject political atheists, alt atheism, caused atheism, political atheists organization, subject amusing atheists, amusing atheists
166 atheists organization, political atheists organization, christian morality organization, caused atheism organization, morality organization, atheism organization, atheists organization california, subject amusing atheists, cwru edu article, alt atheism
172 biblical, read bible, caused atheism, agnostics, caused atheism organization, atheists agnostics, christianity, alt atheism, atheism, christian morality organization
173 objective morality, morality, subject christian morality, christian morality, natural morality, say christian morality, morality organization, christian morality organization, behavior moral, moral
175 atheism, atheism organization, caused atheism organization, atheists agnostics, caused atheism, subject political atheists, alt atheism, genocide caused atheism, subject amusing atheists, amusing atheists
176 rushdie islamic law, subject rushdie islamic, islamic genocide, islamic law, genocide caused atheism, subject islamic, islamic law organization, islamic genocide organization, rushdie islamic, islamic authority

You might also observe that the output of this model is a regular document-topic matrix and isn't ragged.

print(doc_topic_matrix.shape)
# (1024, 6)

This way, the model has the same API as other Turftopic models and behaves like the top2vec package, making migration easier.

API Reference

Encoder

turftopic.late.LateSentenceTransformer

Bases: SentenceTransformer

SentenceTransformer model that can produce token and window-level embeddings. Its output can be used by topic models that can use multi-vector document representations.

Warning

This is not checked yet in the library, but we recommend that you use SentenceTransformers that are a) mean pooled and b) L2 normalized. This will guarantee that the token/window embeddings are in the same embedding space as the documents.

Source code in turftopic/late.py
class LateSentenceTransformer(SentenceTransformer):
    """SentenceTransformer model that can produce token and window-level embeddings.
    Its output can be used by topic models that can use multi-vector document representations.

    !!! warning
        This is not checked yet in the library,
        but we recommend that you use SentenceTransformers that are
        a) **Mean pooled**
        b) **L2 Normalized**
        This will guarantee that the token/window embeddings are in the same embedding space as the documents.
    """

    has_used_token_level = False

    def encode(
        self, sentences: Union[str, list[str], np.ndarray], *args, **kwargs
    ):
        if not self.has_used_token_level:
            warnings.warn(
                "Encoder is contextual but topic model is not using contextual embeddings. Perhaps you wanted to use another topic model."
            )
        return super().encode(sentences, *args, **kwargs)

    def _encode_tokens(
        self,
        texts,
        batch_size=32,
        show_progress_bar=True,
    ) -> tuple[list[np.ndarray], list[Offsets]]:
        """
        Returns
        -------
        token_embeddings: list[np.ndarray]
            Embedding matrix of tokens for each document.
        offsets: list[list[tuple[int, int]]]
            Start and end character of each token in each document.
        """
        self.has_used_token_level = True
        token_embeddings = []
        offsets = []
        for start_index in trange(
            0,
            len(texts),
            batch_size,
            desc="Encoding batches...",
        ):
            batch = texts[start_index : start_index + batch_size]
            features = self.tokenize(batch)
            with torch.no_grad():
                output_features = self.forward(features)
            n_tokens = output_features["attention_mask"].sum(axis=1)
            # Find first nonzero elements in each document
            # The document could be padded from the left, so we have to watch out for this.
            start_token = torch.argmax(
                (output_features["attention_mask"] > 0).to(torch.long), axis=1
            )
            end_token = start_token + n_tokens
            for i_doc in range(len(batch)):
                _token_embeddings = (
                    output_features["token_embeddings"][
                        i_doc, start_token[i_doc] : end_token[i_doc], :
                    ]
                    .float()
                    .numpy(force=True)
                )
                _n = _token_embeddings.shape[0]
                # We extract the character offsets and prune it at the maximum context length
                _offsets = self.tokenizer(
                    batch[i_doc], return_offsets_mapping=True, verbose=False
                )["offset_mapping"][:_n]
                token_embeddings.append(_token_embeddings)
                offsets.append(_offsets)
        return token_embeddings, offsets

    def encode_tokens(
        self,
        sentences: list[str],
        batch_size: int = 32,
        show_progress_bar: bool = True,
    ):
        """Produces contextual token embeddings over all documents.

        Parameters
        ----------
        sentences: list[str]
            Documents to encode contextually.
        batch_size: int, default 32
            Size of the batch of document to encode at once.
        show_progress_bar: bool, default True
            Indicates whether a progress bar should be displayed when encoding.

        Returns
        -------
        token_embeddings: list[np.ndarray]
            Embedding matrix of tokens for each document.
        offsets: list[list[tuple[int, int]]]
            Start and end character of each token in each document.
        """
        # This is needed because the above implementation does not normalize embeddings,
        # which normally happens to document embeddings.
        token_embeddings, offsets = self._encode_tokens(
            sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar,
        )
        token_embeddings = [normalize(emb) for emb in token_embeddings]
        return token_embeddings, offsets

    def encode_windows(
        self,
        sentences: list[str],
        batch_size: int = 32,
        window_size: int = 50,
        step_size: int = 40,
        show_progress_bar: bool = True,
    ):
        """Produces contextual embeddings for a sliding window of tokens similar to C-Top2Vec.

        Parameters
        ----------
        sentences: list[str]
            Documents to encode contextually.
        batch_size: int, default 32
            Size of the batch of document to encode at once.
        window_size: int, default 50
            Size of the sliding window.
        step_size: int, default 40
            Step size of the window.
            If step_size < window_size, windows will overlap.
            If step_size == window_size, then windows are separate.
            If step_size > window_size, there will be gaps between the windows.
            In this case, we throw a warning, as this is probably unintended behaviour.
        show_progress_bar: bool, default True
            Indicates whether a progress bar should be displayed when encoding.

        Returns
        -------
        window_embeddings: list[np.ndarray]
            Embedding matrix of windows for each document.
        offsets: list[list[tuple[int, int]]]
            Start and end character of each token in each document.
        """
        token_embeddings, token_offsets = self._encode_tokens(
            sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar,
        )
        window_embeddings = []
        window_offsets = []
        for emb, offs in zip(token_embeddings, token_offsets):
            _offsets = []
            _embeddings = []
            for start_index in range(0, len(emb), step_size):
                end_index = start_index + window_size
                window_emb = np.mean(emb[start_index:end_index], axis=0)
                off = offs[start_index:end_index]
                _embeddings.append(window_emb)
                _offsets.append((off[0][0], off[-1][1]))
            window_embeddings.append(normalize(np.stack(_embeddings)))
            window_offsets.append(_offsets)
        return window_embeddings, window_offsets

encode_tokens(sentences, batch_size=32, show_progress_bar=True)

Produces contextual token embeddings over all documents.

Parameters:

Name Type Description Default
sentences list[str]

Documents to encode contextually.

required
batch_size int

Size of the batch of document to encode at once.

32
show_progress_bar bool

Indicates whether a progress bar should be displayed when encoding.

True

Returns:

Name Type Description
token_embeddings list[ndarray]

Embedding matrix of tokens for each document.

offsets list[list[tuple[int, int]]]

Start and end character of each token in each document.

Source code in turftopic/late.py
def encode_tokens(
    self,
    sentences: list[str],
    batch_size: int = 32,
    show_progress_bar: bool = True,
):
    """Produces contextual token embeddings over all documents.

    Parameters
    ----------
    sentences: list[str]
        Documents to encode contextually.
    batch_size: int, default 32
        Size of the batch of document to encode at once.
    show_progress_bar: bool, default True
        Indicates whether a progress bar should be displayed when encoding.

    Returns
    -------
    token_embeddings: list[np.ndarray]
        Embedding matrix of tokens for each document.
    offsets: list[list[tuple[int, int]]]
        Start and end character of each token in each document.
    """
    # This is needed because the above implementation does not normalize embeddings,
    # which normally happens to document embeddings.
    token_embeddings, offsets = self._encode_tokens(
        sentences,
        batch_size=batch_size,
        show_progress_bar=show_progress_bar,
    )
    token_embeddings = [normalize(emb) for emb in token_embeddings]
    return token_embeddings, offsets

encode_windows(sentences, batch_size=32, window_size=50, step_size=40, show_progress_bar=True)

Produces contextual embeddings for a sliding window of tokens similar to C-Top2Vec.

Parameters:

Name Type Description Default
sentences list[str]

Documents to encode contextually.

required
batch_size int

Size of the batch of document to encode at once.

32
window_size int

Size of the sliding window.

50
step_size int

Step size of the window. If step_size < window_size, windows will overlap. If step_size == window_size, then windows are separate. If step_size > window_size, there will be gaps between the windows. In this case, we throw a warning, as this is probably unintended behaviour.

40
show_progress_bar bool

Indicates whether a progress bar should be displayed when encoding.

True

Returns:

Name Type Description
window_embeddings list[ndarray]

Embedding matrix of windows for each document.

offsets list[list[tuple[int, int]]]

Start and end character of each token in each document.

Source code in turftopic/late.py
def encode_windows(
    self,
    sentences: list[str],
    batch_size: int = 32,
    window_size: int = 50,
    step_size: int = 40,
    show_progress_bar: bool = True,
):
    """Produces contextual embeddings for a sliding window of tokens similar to C-Top2Vec.

    Parameters
    ----------
    sentences: list[str]
        Documents to encode contextually.
    batch_size: int, default 32
        Size of the batch of document to encode at once.
    window_size: int, default 50
        Size of the sliding window.
    step_size: int, default 40
        Step size of the window.
        If step_size < window_size, windows will overlap.
        If step_size == window_size, then windows are separate.
        If step_size > window_size, there will be gaps between the windows.
        In this case, we throw a warning, as this is probably unintended behaviour.
    show_progress_bar: bool, default True
        Indicates whether a progress bar should be displayed when encoding.

    Returns
    -------
    window_embeddings: list[np.ndarray]
        Embedding matrix of windows for each document.
    offsets: list[list[tuple[int, int]]]
        Start and end character of each token in each document.
    """
    token_embeddings, token_offsets = self._encode_tokens(
        sentences,
        batch_size=batch_size,
        show_progress_bar=show_progress_bar,
    )
    window_embeddings = []
    window_offsets = []
    for emb, offs in zip(token_embeddings, token_offsets):
        _offsets = []
        _embeddings = []
        for start_index in range(0, len(emb), step_size):
            end_index = start_index + window_size
            window_emb = np.mean(emb[start_index:end_index], axis=0)
            off = offs[start_index:end_index]
            _embeddings.append(window_emb)
            _offsets.append((off[0][0], off[-1][1]))
        window_embeddings.append(normalize(np.stack(_embeddings)))
        window_offsets.append(_offsets)
    return window_embeddings, window_offsets

Wrapper

turftopic.late.LateWrapper

Bases: ContextualModel, TransformerMixin

Wraps an existing Turftopic model so that it can accept and create multi-vector document representations.

Warning

The model HAS TO HAVE a late interaction encoder model (e.g. LateSentenceTransformer)

Parameters:

Name Type Description Default
model TransformerMixin

Turftopic model to turn into late-interaction model.

required
batch_size Optional[int]

Batch size of the transformer.

32
window_size Optional[int]

Size of the sliding window to average tokens over. If None, documents will be represented at a token level.

None
step_size Optional[int]

Step size of the window. If (step_size == None) or (step_size == window_size), then windows are separate. If step_size < window_size, windows will overlap. If step_size > window_size, there will be gaps between the windows. In this case, we throw a warning, as this is probably unintended behaviour.

None
pooling Optional[Callable]

Indicates whether and how to pool document-topic matrices. If None, multi-vector topic proportions are returned in a ragged array. If Callable, multiple vectors are averaged with the callable in each document. You could for example take the mean by specifying pooling=np.nanmean.

None
Source code in turftopic/late.py
class LateWrapper(ContextualModel, TransformerMixin):
    """Wraps existing Turftopic model so that they can accept and create
    multi-vector document representations.

    !!! warning
        The model HAS TO HAVE a late interaction encoder model
        (e.g. `LateSentenceTransformer`)

    Parameters
    ----------
    model
        Turftopic model to turn into late-interaction model.
    batch_size: int, default 32
        Batch size of the transformer.
    window_size: int, default None
        Size of the sliding window to average tokens over.
        If None, documents will be represented at a token level.
    step_size: int, default None
        Step size of the window.
        If (step_size == None) or (step_size == window_size), then windows are separate.
        If step_size < window_size, windows will overlap.
        If step_size > window_size, there will be gaps between the windows.
        In this case, we throw a warning, as this is probably unintended behaviour.
    pooling: Callable, default None
        Indicates whether and how to pool document-topic matrices.
        If None, multi-vector topic proportions are returned in a ragged array.
        If Callable, multiple vectors are averaged with the callable in each document.
        You could for example take the mean by specifying `pooling=np.nanmean`.
    """

    def __init__(
        self,
        model: TransformerMixin,
        batch_size: Optional[int] = 32,
        window_size: Optional[int] = None,
        step_size: Optional[int] = None,
        pooling: Optional[Callable] = None,
    ):
        self.model = model
        self.batch_size = batch_size
        self.pooling = pooling
        self.window_size = window_size
        self.step_size = step_size

    def encode_late(
        self, raw_documents: list[str]
    ) -> tuple[np.ndarray, list[Offsets]]:
        if self.window_size is None:
            embeddings, offsets = self.model.encoder.encode_tokens(
                raw_documents, batch_size=self.batch_size
            )
            return embeddings, offsets
        # If the window_size is specified, but not step_size, we set the step size to the window size
        # Thereby getting non-overlapping windows
        step_size = (
            self.window_size if self.step_size is None else self.step_size
        )
        embeddings, offsets = self.model.encoder.encode_windows(
            raw_documents,
            batch_size=self.batch_size,
            window_size=self.window_size,
            step_size=step_size,
        )
        return embeddings, offsets

    def transform(
        self,
        raw_documents: list[str],
        embeddings: list[np.ndarray] = None,
        offsets: list[Offsets] = None,
    ):
        if (embeddings is None) or (offsets is None):
            embeddings, offsets = self.encode_late(raw_documents)
        flat_embeddings, lengths = flatten_repr(embeddings)
        chunks = get_document_chunks(raw_documents, offsets)
        out_array = self.model.transform(chunks, embeddings=flat_embeddings)
        if self.pooling is None:
            return unflatten_repr(out_array, lengths), offsets
        else:
            return pool_flat(out_array, lengths)

    def fit_transform(
        self,
        raw_documents: list[str],
        y=None,
        embeddings: list[np.ndarray] = None,
        offsets: list[Offsets] = None,
    ):
        if (embeddings is None) or (offsets is None):
            embeddings, offsets = self.encode_late(raw_documents)
        flat_embeddings, lengths = flatten_repr(embeddings)
        chunks = get_document_chunks(raw_documents, offsets)
        out_array = self.model.fit_transform(
            chunks, embeddings=flat_embeddings
        )
        if self.pooling is None:
            return unflatten_repr(out_array, lengths), offsets
        else:
            return pool_flat(out_array, lengths)

    @property
    def components_(self):
        return self.model.components_

    @property
    def hierarchy(self):
        return self.model.hierarchy

    @property
    def topic_names(self):
        return self.model.topic_names

    @property
    def classes_(self):
        return self.model.classes_

    @property
    def vectorizer(self):
        return self.model.vectorizer

Utility functions

turftopic.late.flatten_repr(repr)

Flattens ragged array to normal array.

Parameters:

Name Type Description Default
repr list[ndarray]

Ragged representation array.

required

Returns:

Name Type Description
flat_repr ndarray

Flattened representation array.

lengths list[int]

Length of each document in the corpus.

Source code in turftopic/late.py
def flatten_repr(
    repr: list[np.ndarray],
) -> tuple[np.ndarray, Lengths]:
    """Flattens ragged array to normal array.

    Parameters
    ----------
    repr: list[ndarray]
        Ragged representation array.

    Returns
    -------
    flat_repr: ndarray
        Flattened representation array.
    lengths: list[int]
        Length of each document in the corpus.
    """
    lengths = [r.shape[0] for r in repr]
    return np.concatenate(repr, axis=0), lengths

turftopic.late.unflatten_repr(flat_repr, lengths)

Unflattens flat array to ragged array.

Parameters:

Name Type Description Default
flat_repr ndarray

Flattened representation array.

required
lengths Lengths

Length of each document in the corpus.

required

Returns:

Name Type Description
repr list[ndarray]

Ragged representation array.

Source code in turftopic/late.py
def unflatten_repr(
    flat_repr: np.ndarray, lengths: Lengths
) -> list[np.ndarray]:
    """Unflattens flat array to ragged array.

    Parameters
    ----------
    flat_repr: ndarray
        Flattened representation array.
    lengths: list[int]
        Length of each document in the corpus.

    Returns
    -------
    repr: list[ndarray]
        Ragged representation array.

    """
    repr = []
    start_index = 0
    for length in lengths:
        repr.append(flat_repr[start_index : start_index + length])
        start_index += length
    return repr

turftopic.late.pool_flat(flat_repr, lengths, agg=np.nanmean)

Pools vectors within documents using the agg function.

Parameters:

Name Type Description Default
flat_repr ndarray

Flattened document representations.

required
lengths Lengths

Number of tokens in each document.

required

Returns:

Type Description
ndarray of shape (n_documents, n_dims)

Pooled representation for each document.

Source code in turftopic/late.py
def pool_flat(flat_repr: np.ndarray, lengths: Lengths, agg=np.nanmean):
    """Pools vectors within documents using the agg function.

    Parameters
    ----------
    flat_repr: ndarray of shape (n_total_tokens, n_dims)
        Flattened document representations.
    lengths: Lengths
        Number of tokens in each document.

    Returns
    -------
    ndarray of shape (n_documents, n_dims)
        Pooled representation for each document.
    """
    pooled = []
    start_index = 0
    for length in lengths:
        pooled.append(
            agg(flat_repr[start_index : start_index + length], axis=0)
        )
        start_index += length
    return np.stack(pooled)

turftopic.late.get_document_chunks(raw_documents, offsets)

Extracts text chunks from documents based on token/window offsets.

Parameters:

Name Type Description Default
raw_documents list[str]

Text documents.

required
offsets list[Offsets]

Offsets returned when encoding.

required

Returns:

Type Description
list[str]

Text chunks of tokens/windows in the documents.

Source code in turftopic/late.py
def get_document_chunks(
    raw_documents: list[str], offsets: list[Offsets]
) -> list[str]:
    """Extracts text chunks from documents based on token/window offsets.

    Parameters
    ----------
    raw_documents: list[str]
        Text documents.
    offsets: list[Offsets]
        Offsets returned when encoding.

    Returns
    -------
    list[str]
        Text chunks of tokens/windows in the documents.
    """
    chunks = []
    for doc, _offs in zip(raw_documents, offsets):
        for start_char, end_char in _offs:
            chunks.append(doc[start_char:end_char])
    return chunks