How Airtrain Clusters Textual Data

Josh Bauer · 7/18/2024

Why cluster?

Dealing with tons of data is tough, but we often have to do it. When it comes to well-structured numerical or categorical data, we have an assortment of tools to help us figure things out. Scatter plots, heat maps, histograms, violin plots, parallel coordinate plots – the list goes on. But once you start working with unstructured data, like user comments, product descriptions, or news articles, your options shrink a lot.

At Airtrain, we aim to help people make sense of their unstructured text data, whether that's building datasets for LLM fine-tuning, assembling evaluation sets, or just figuring out how people are using their products. One of the coolest tools we’ve developed for these purposes is semantic clustering. The basic idea is to automatically group your data so that similar pieces are together. To see this in action, check out a clustering we did for 50k articles from CNN and the DailyMail. We didn’t write any special code for news articles—just uploaded the data to our platform and let it do its thing.

clusters

We can see that a good chunk of the articles, around 8%, are about athletic achievements. When we look closer at the “Athletic Achievements” group, we see that “Formula One Racing” makes up about 15% of those articles.

single cluster

So, how do we create these clusters of articles? To explain that, we need to do a quick refresher on embeddings.

Embeddings overview

If you spend any time with folks working on language modeling, you’ll hear the term “embedding” thrown around a lot. But what exactly does it mean? To explain, let's start with some common problems in language modeling and then work our way back to the solution: embeddings.

For many language modeling problems, it's useful to know how “similar” two chunks of text are.

Here are a few examples:

  • Clustering: This is what we’re focused on—we want to group related texts so that texts in the same group share something meaningful. “Similar” = in the same group.
  • Semantic Search: Given a piece of text (like a question), we want to find “similar” texts (like those that might contain the answer).
  • Classification: Like clustering, but you decide in advance what groups to put the texts into based on their similarity.
  • Evaluation: You might have a reference text and want to grade another text on how well it matches the reference (how “similar” they are).

One way to measure similarity is to assign a single number that captures how “close” two pieces of text are. It would be great to have a function like similarity_score(text_1, text_2) -> float that returns lower numbers for similar texts and higher numbers for less similar ones (behaving, in effect, like a distance). For example:

| Label | Text |
| --- | --- |
| A | Researchers hope advancements in computer science will unlock new climate modeling capabilities. |
| B | Homer Simpson once blew up a nuclear power facility. |
| C | Lisa Simpson once invented a super-intelligent AI. |
| D | Nuclear power is considered by some to be a key element towards addressing the climate crisis. |
| E | The properties of modern GPUs are ideal for AI as well as parallel simulations for complex systems. |

If we consider each pair of statements here, we can get the following table:

| Label Pair | Proximity | Proximity justification |
| --- | --- | --- |
| A ↔ B | far (high #) | |
| A ↔ C | far (high #) | |
| A ↔ D | close (low #) | Both discuss solutions for the climate crisis. |
| A ↔ E | close (low #) | Both mention scientific simulations and modern computing. |
| B ↔ C | close (low #) | Both discuss the activities of Simpsons characters. |
| B ↔ D | close (low #) | Both discuss nuclear power. |
| B ↔ E | far (high #) | |
| C ↔ D | far (high #) | |
| C ↔ E | close (low #) | Both mention AI development. |
| D ↔ E | far (high #) | |

Our question now is how we could define our similarity_score(text_1, text_2) -> float function. As you can imagine, a table like this grows quickly as you add more texts: n texts need n(n - 1)/2 pairwise entries. Directly mapping pairs of texts to a number could get pretty tricky. Instead, we could assign a number to each text (like get_text_location(text) -> float) and then derive the proximity scores from the difference between those numbers: distance = abs(get_text_location(text_a) - get_text_location(text_b)). Let’s look at the distance relationships for all the pairs we have and see how we might assign locations on a number line.
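
To make this concrete, here is a minimal sketch of the one-dimensional idea. The locations below are hand-picked to mirror the table above; none of this comes from a real model:

```python
# A minimal sketch of the 1D "location" idea. These locations are
# hand-picked for illustration, not produced by any real embedding model.
locations = {"D": 0.0, "B": 1.0, "C": 2.0, "E": 3.0}

def get_text_location(label: str) -> float:
    """Stand-in for a learned text -> number mapping."""
    return locations[label]

def similarity_score(label_a: str, label_b: str) -> float:
    """Lower = more similar, matching the convention above."""
    return abs(get_text_location(label_a) - get_text_location(label_b))

print(similarity_score("B", "D"))  # 1.0 -> close, as the table wants
print(similarity_score("D", "E"))  # 3.0 -> far, as the table wants
# But no location for "A" can be simultaneously close to both
# D (at 0.0) and E (at 3.0) without also landing close to B or C.
```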

distance

Uh oh…things were going well until we tried to place “A.” There’s no good way to place it to satisfy all the relationships we want: it needs to be close to D (both talk about climate) and to E (both talk about simulation and computing). But D and E are far from each other. If we put A between them, it ends up close to B and C, which we don’t want. One way to solve this is to spread the points out vertically as well as horizontally, instead of just horizontally.

pentagram

The shortest 5 lines are in blue, and the longer ones are in black. This setup perfectly shows the relationships we want! All the similar statements have shorter lines (smaller distances), and the less similar ones have longer lines (larger distances).

By assigning each piece of text a location in two dimensions, we captured more complex relationships between the texts than we could with just one dimension. This process of assigning a text to a location on a number line, 2D plane, 3D space, or even “n” dimensions is called an “embedding” of the text. The geometric space where we lay out these points is called the “embedding space.” Just like moving from 1D to 2D lets us show more complex relationships, adding more dimensions gives us even more flexibility. In practice, modern language modeling applications use hundreds or even thousands of dimensions, allowing for very rich relational representations.

It’s important to highlight a couple of things about our particular embedding. First, the exact coordinates of any point don’t really matter; what we care about are the distances between pairs of points, and those stay the same if we rotate the whole picture or move the origin (0, 0). Second, we only care about relative distances (like “Are B & C closer than A & C?”) rather than exact measurements (“Is the distance between B & C 42 or 1000?”).

pentagram axis

These same properties hold in higher dimensions too, which means it’s usually pretty pointless to look at any particular embedding vector (the coordinates of the text in the embedding space) by itself. Given this flexibility, a common choice is to define your embedding function (get_text_location(text) → EmbeddingCoordinates) so that the coordinates are always a distance of 1 from the origin. This means we lose 1 degree of freedom (the radial direction), but it also has some advantages. One big benefit is that it sets a scale for the distances—the farthest apart two points can be with these rules is 2 (on opposite sides of the n-dimensional unit sphere). For the extra-curious readers, it also means that cosine similarity (which is easier to compute) has a simple relationship to the more conventional definition of distance (Euclidean distance).
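
For those same extra-curious readers, that relationship is easy to check numerically. A quick, illustrative numpy snippet (not Airtrain code) verifies that for unit vectors, Euclidean distance and cosine similarity carry the same information:

```python
import numpy as np

rng = np.random.default_rng(0)

# Two random 384-dim vectors normalized onto the unit sphere, standing in
# for unit-length embedding vectors.
u, v = rng.normal(size=(2, 384))
u, v = u / np.linalg.norm(u), v / np.linalg.norm(v)

cosine_similarity = u @ v                  # a single dot product
euclidean_distance = np.linalg.norm(u - v)

# For unit vectors, ||u - v||^2 = 2 - 2 * (u . v), so the two measures
# always rank pairs of points the same way.
assert np.isclose(euclidean_distance, np.sqrt(2 - 2 * cosine_similarity))
```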

Another thing we haven’t covered yet is how you define the embedding function in the first place. As you might guess, this usually involves training a machine-learning model to map text to coordinates in the embedding space, and there are several techniques for doing it. You can compare embedding models on public leaderboards such as MTEB (the Massive Text Embedding Benchmark), which scores models across a broad set of embedding tasks.
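
As a concrete illustration, here is how you might compute embeddings with the open-source sentence-transformers library. The model name is just a popular small example, not necessarily what Airtrain runs in production:

```python
from sentence_transformers import SentenceTransformer

# An off-the-shelf embedding model, chosen for illustration only.
model = SentenceTransformer("all-MiniLM-L6-v2")

texts = [
    "Researchers hope advancements in computer science will unlock new climate modeling capabilities.",
    "Homer Simpson once blew up a nuclear power facility.",
]

# normalize_embeddings=True puts every vector on the unit sphere,
# as discussed above.
embeddings = model.encode(texts, normalize_embeddings=True)
print(embeddings.shape)  # (2, 384) for this particular model
```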

Clustering

With that understanding in place, we’re ready to tackle clustering. First, we’ll convert all our pieces of text into points in the embedding space. Then, we’ll look for groups of points that are close to each other. For example, if we look at a 2D projection of the embeddings for some articles in the “Entertainment Recognition” cluster, we can see they are grouped within a specific region of this 2D representation of the embedding space.

entertainment

You might wonder why those green “Entertainment Recognition” articles are grouped together, but the other blurred points nearby aren’t in the same group. It’s important to remember that these points have coordinates in a high-dimensional space (at least hundreds of dimensions), but we’re viewing them in just 2 dimensions. This means some points that look close in this 2D view might not be close in the full embedding space. It’s like how the red and black points in this image of a mammoth might look close from above (one possible 2D view of the points in 3D), but when viewed from the side (another 2D view of the 3D data), you can see they are far apart.

bev mammoth

mammoth

Images from https://pair-code.github.io/understanding-umap/
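
A toy numpy example shows the same effect with made-up points: two points can coincide in one 2D projection while being far apart in the full space.

```python
import numpy as np

# Two made-up points that are far apart in 3D...
red = np.array([1.0, 1.0, 0.0])
black = np.array([1.0, 1.0, 10.0])
print(np.linalg.norm(red - black))          # 10.0 -> far in the full space

# ...but indistinguishable when projected "from above" onto the first
# two axes, just like the mammoth views.
print(np.linalg.norm(red[:2] - black[:2]))  # 0.0 -> they overlap in 2D
```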

So, how do we find groups of points that are close to one another? Luckily, clustering points in high-dimensional spaces is a well-explored problem. Two popular algorithms for this are k-means and HDBSCAN. Each has its pros and cons. For example, k-means assigns every point to a cluster, while HDBSCAN labels some points as noise. On the other hand, HDBSCAN can capture more complex cluster shapes, whereas k-means expects “spherical” clusters. This matters because your data might have clusters with various shapes, like the tusks or legs of the mammoth above.
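
Both algorithms are widely available. Here is a rough sketch using scikit-learn (HDBSCAN ships with scikit-learn 1.3+, or via the standalone hdbscan package); the random data is a stand-in for real embeddings:

```python
import numpy as np
from sklearn.cluster import KMeans, HDBSCAN

# Stand-in for real embeddings: 1,000 points in a 384-dim space.
embeddings = np.random.default_rng(0).normal(size=(1000, 384))

# k-means: every point gets a cluster; shapes are assumed roughly spherical.
kmeans_labels = KMeans(n_clusters=10, n_init="auto").fit_predict(embeddings)

# HDBSCAN: handles arbitrary cluster shapes, but points it can't place
# confidently are labeled -1 ("noise").
hdbscan_labels = HDBSCAN(min_cluster_size=15).fit_predict(embeddings)
print((hdbscan_labels == -1).sum(), "points labeled as noise")
```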

At Airtrain, we use a combination of these algorithms to capture clusters with complex shapes while also keeping the maximum cluster size in check and limiting the amount labeled as noise.

Cluster Labeling

While clustering is a well-studied area of data science, a raw clustering algorithm just gives you an assignment of each point (each row of text data) to a cluster with a numeric ID. But saying you have 719 points in “cluster number 42” isn’t very helpful when you’re trying to understand the big picture of your dataset. You’d need to dive into “cluster number 42” yourself to figure out what the data in that cluster has in common. You could look at a few rows, identify a theme, come up with a name, and then label “cluster number 42” with something meaningful like “Formula One Racing.”

Thankfully, in the age of LLMs, we can automate this labeling task by handing it off to a model. That’s exactly what we do at Airtrain. To help the model come up with a good, meaningful name, we first have it describe the examples from each cluster before deciding on a name. If you want to get fancy, you can consider this part of “Chain of Thought” prompting.
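
Here is a hypothetical sketch of that describe-then-name flow. The prompt wording and the complete_fn helper are illustrative stand-ins, not Airtrain's actual prompts:

```python
def label_cluster(examples: list[str], complete_fn) -> str:
    """Name a cluster via describe-then-name, a light form of
    chain-of-thought. `complete_fn` is any text -> text LLM call."""
    sample = "\n".join(f"- {text[:200]}" for text in examples[:20])

    # Step 1: have the model describe what the examples share.
    description = complete_fn(
        "Here are text samples from one cluster:\n"
        f"{sample}\n\n"
        "Describe in 2-3 sentences what these texts have in common."
    )

    # Step 2: condense that description into a short, human-friendly name.
    return complete_fn(
        f"Cluster description: {description}\n\n"
        "Give this cluster a short name of at most 4 words, "
        "like 'Formula One Racing'."
    )
```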

Representing Hierarchy In Topics

Remember our goal from the beginning of the article? We organized the CNN data into broad, general topics (like “Athletic Achievements”) with more detailed, smaller topics within those (like “Formula One Racing” within “Athletic Achievements”). This approach is great for exploring datasets because you can start by getting a sense of the major categories, then “drill down” to more specific ones, and finally “drill down” further to look at individual rows of data.

But where does this hierarchy come from? So far, we’ve only talked about assigning one group to each row, not about breaking those groups down further or combining narrow groups into broader ones. At Airtrain, we take a bottom-up approach: we start with a large number of narrowly focused clusters (as described above) and then combine these narrow clusters into more general categories. Creating the broader clusters works much like creating the narrower ones, except that instead of using the raw rows of data, we use the names and descriptions of the narrow clusters as our inputs.
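
In rough pseudocode, the same pipeline simply runs twice. The embed, cluster_points, and label_cluster callables below are stand-ins for the pieces sketched earlier, not Airtrain's internals:

```python
from collections import defaultdict

def build_hierarchy(rows, embed, cluster_points, label_cluster):
    """Rough sketch of the bottom-up hierarchy described above."""
    # Pass 1: many narrow clusters over the raw rows.
    narrow_ids = cluster_points(embed(rows))
    by_cluster = defaultdict(list)
    for row, cid in zip(rows, narrow_ids):
        by_cluster[cid].append(row)
    narrow_names = {cid: label_cluster(texts) for cid, texts in by_cluster.items()}

    # Pass 2: cluster the narrow clusters' *names/descriptions*, not the
    # raw rows, to get broad categories like "Athletic Achievements".
    cids = list(narrow_names)
    broad_ids = cluster_points(embed([narrow_names[c] for c in cids]))
    return dict(zip(cids, broad_ids))  # narrow cluster -> broad cluster
```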

An Example Use-Case

Now that we've got these hierarchical clusters, let's put them to work on a real-world problem. Imagine we want to use this CNN dataset to train a model that extracts highlights from news articles. Picking high-quality training data is crucial, so we need to hunt down any low-quality rows that might mess things up. The broad "Miscellaneous" group seems like a good place to start looking for data that doesn't quite fit. In Airtrain, we can click on that slice of the pie to expand all the sub-clusters within it. When we do that and hover over the sub-clusters, we spot "News Quizzes", which looks like it might contain data that isn't useful for our training objective.

news quizzes

We can click on that sub-cluster to add it as a data filter and start digging into the corresponding rows.

filter

Sure enough, these rows are pretty different from the "article about a unified subject" type we're aiming for with our model. Now we can kick these rows out of our dataset and keep refining it to get ready for tuning.
