
Here is an intuitive explanation of how sparse autoencoders work

2024-08-05




Machine Heart Report

Editor: Panda

In short: matrix → ReLU activation → matrix

Sparse Autoencoders (SAEs) are an increasingly common tool for interpreting machine learning models (although SAEs have been around since 1997).

Machine learning models and LLMs are becoming increasingly powerful and useful, but they are still black boxes: we do not understand how they accomplish their tasks. Understanding how they work would be valuable.

SAEs help us decompose a model's computations into understandable components. Recently, Adam Karvonen, an LLM interpretability researcher, published a blog post that gives an intuitive explanation of how SAEs work.

The interpretability challenge

The most natural building blocks of a neural network are individual neurons. Unfortunately, individual neurons do not easily correspond to individual concepts, such as academic citations, English conversations, HTTP requests, and Korean text. In neural networks, concepts are represented by combinations of neurons, which is called superposition.

This happens because many variables in the world are naturally sparse.

For example, a famous person’s birthplace may appear in fewer than one in a billion training tokens, yet a modern LLM can still learn this fact along with a great deal of other knowledge about the world. The training data contains more individual facts and concepts than the model has neurons, which is probably why superposition occurs.

Sparse autoencoders (SAEs) have recently become an increasingly popular technique for decomposing neural networks into understandable components. Their design is inspired by the sparse coding hypothesis in neuroscience, and SAEs are now one of the most promising tools for understanding artificial neural networks. An SAE is similar to a standard autoencoder.

A conventional autoencoder is a neural network that compresses and reconstructs input data.

For example, if the input is a 100-dimensional vector (a list of 100 values), the autoencoder first passes the input through an encoder layer to compress it into a 50-dimensional vector, and then feeds this compressed representation into the decoder to produce a 100-dimensional output vector. The reconstruction is usually imperfect, because squeezing the input through a smaller bottleneck discards information.



A diagram of a standard autoencoder with a 1x4 input vector, a 1x2 intermediate state vector, and a 1x4 output vector. The color of the cell indicates the activation value. The output is an imperfect reconstruction of the input.
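The 100 → 50 → 100 autoencoder described above can be sketched in PyTorch as follows. This is a minimal illustration, not code from the original post: the class name `StandardAutoEncoder` and the use of a single linear layer for each of the encoder and decoder are assumptions, and real autoencoders are often deeper.

```python
import torch
import torch.nn as nn


class StandardAutoEncoder(nn.Module):
    """Hypothetical example: compress a 100-dim input to 50 dims, then reconstruct it."""

    def __init__(self, input_dim: int = 100, bottleneck_dim: int = 50):
        super().__init__()
        self.encoder = nn.Linear(input_dim, bottleneck_dim)
        self.decoder = nn.Linear(bottleneck_dim, input_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        code = torch.relu(self.encoder(x))   # (batch, 50) compressed representation
        reconstruction = self.decoder(code)  # (batch, 100) attempted reconstruction
        return reconstruction, code


torch.manual_seed(0)
model = StandardAutoEncoder()
x = torch.randn(8, 100)  # a batch of 8 random 100-dim inputs
reconstruction, code = model(x)
print(code.shape, reconstruction.shape)  # torch.Size([8, 50]) torch.Size([8, 100])

# Training would minimize this squared reconstruction error. Because the decoder's
# outputs lie in a 50-dimensional (affine) subspace, arbitrary 100-dim inputs
# cannot be reconstructed exactly.
loss = (reconstruction - x).pow(2).sum(dim=-1).mean()
```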

Explaining Sparse Autoencoders

How Sparse Autoencoders Work

A sparse autoencoder transforms an input vector into an intermediate vector whose dimension can be higher than, equal to, or lower than that of the input. When SAEs are applied to LLMs, the intermediate vector is usually higher-dimensional than the input. In that case, without additional constraints, the task is trivial: the SAE can reconstruct the input perfectly by learning an identity mapping, and nothing useful is learned. To prevent this, we add constraints; the main one is a sparsity penalty in the training loss, which forces the SAE to produce sparse intermediate vectors.

For example, we can expand the 100-dimensional input into a 200-dimensional encoded representation vector, and we can train the SAE to have only about 20 non-zero elements in the encoded representation.



Diagram of a sparse autoencoder. Note that the intermediate activations are sparse, with only 2 non-zero values.
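To see why the sparsity penalty is needed, here is a hand-built (not trained) example, assuming a ReLU encoder and a linear decoder with biases omitted. With a 200-dimensional code, the encoder can simply copy `x` and `-x`, and the decoder recombines them as relu(x) - relu(-x) = x. Reconstruction is perfect, yet half of the code is active, so nothing sparse or interpretable has been learned.

```python
import torch

torch.manual_seed(0)
d_in = 100

# Hand-built (not trained) weights, biases omitted. The encoder writes [x, -x]
# into a 200-dim code, and ReLU keeps only the positive parts.
W_enc = torch.cat([torch.eye(d_in), -torch.eye(d_in)], dim=1)  # shape (100, 200)
W_dec = W_enc.T                                                 # shape (200, 100)

x = torch.randn(d_in)
code = torch.relu(x @ W_enc)  # [relu(x), relu(-x)]
x_hat = code @ W_dec          # relu(x) - relu(-x) == x

print(torch.allclose(x_hat, x))  # True: perfect reconstruction
print(int((code != 0).sum()))    # 100: half of the 200 entries are active, so not sparse
```

A sparsity penalty makes this kind of dense, identity-like solution expensive, pushing the SAE toward codes with only a handful of active entries.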

We apply SAEs to the intermediate activations of a neural network, which may contain many layers. During the forward pass, intermediate activations exist both within each layer and between layers.

For example, GPT-3 has 96 layers. During the forward pass, each input token is represented by a 12,288-dimensional vector (a list of 12,288 numbers). As the model processes each layer, this vector accumulates all the information the model uses to predict the next token, but it is opaque, making it difficult to understand what information it contains.

We can use an SAE to understand this intermediate activation. An SAE is essentially "matrix → ReLU activation → matrix".

For example, if a GPT-3 SAE has an expansion factor of 4 and its input activations are 12,288-dimensional, then its encoded representations are 49,152-dimensional (12,288 x 4). The first matrix is the encoder matrix, of shape (12,288, 49,152), and the second is the decoder matrix, of shape (49,152, 12,288). Multiplying the GPT activations by the encoder and applying ReLU gives a 49,152-dimensional encoded representation, which is sparse because the SAE loss function encourages sparsity.

Generally, we aim for fewer than 100 non-zero values in the SAE representation. Multiplying the SAE representation by the decoder gives a 12,288-dimensional reconstructed model activation. This reconstruction does not perfectly match the original GPT activation, because the sparsity constraint makes a perfect match difficult to achieve.

Typically, an SAE is used for only one position in the model. For example, we could train an SAE on the intermediate activations between layers 26 and 27. To analyze the information contained in the outputs of all 96 layers of GPT-3, we could train 96 separate SAEs, one for each layer output. If we also wanted to analyze the various intermediate activations within each layer, we would need hundreds of SAEs. To get training data for these SAEs, we would feed the GPT model a large amount of diverse text and collect the intermediate activations at each selected position.
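Collecting activations at one position is commonly done with a PyTorch forward hook. The sketch below uses a hypothetical toy `nn.Sequential` in place of GPT-3 and random inputs in place of tokenized text; with a real LLM the hooked module would be the chosen layer, and activations would have an extra sequence dimension to flatten.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16

# Hypothetical stand-in for a transformer: a stack of 4 small blocks.
toy_model = nn.Sequential(*[nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU()) for _ in range(4)])

collected: list[torch.Tensor] = []


def save_activations(module, inputs, output):
    # Detach so the stored activations do not keep an autograd graph alive.
    collected.append(output.detach())


# Hook the output of block 2, i.e. the activations passed from block 2 to block 3.
handle = toy_model[2].register_forward_hook(save_activations)

with torch.no_grad():
    for _ in range(10):  # stand-in for batches of text fed through the model
        batch = torch.randn(32, d_model)
        toy_model(batch)

handle.remove()  # stop collecting

sae_training_data = torch.cat(collected, dim=0)
print(sae_training_data.shape)  # torch.Size([320, 16])
```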

Here is a PyTorch reference implementation of an SAE. The variables are annotated with shape suffixes, an idea from Noam Shazeer; see: https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd. Note that different SAE implementations often use different biases, normalization schemes, or initialization schemes to maximize performance. The most common addition is some kind of constraint on the norm of the decoder vectors. For more details, see the following implementations:

  • OpenAI: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py#L16
  • SAELens: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L97
  • dictionary_learning: https://github.com/saprmarks/dictionary_learning/blob/main/dictionary.py#L30

```python
import torch
import torch.nn as nn

# D = d_model, F = dictionary_size
# e.g. if d_model = 12288 and dictionary_size = 49152,
# then model_activations_D.shape = (12288,) and encoder_DF.weight.shape = (49152, 12288),
# because nn.Linear stores its weight as (out_features, in_features)


class SparseAutoEncoder(nn.Module):
    """A one-layer autoencoder."""

    def __init__(self, activation_dim: int, dict_size: int):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        # Encoder maps model activations (D) to the dictionary (F); decoder maps back.
        self.encoder_DF = nn.Linear(activation_dim, dict_size, bias=True)
        self.decoder_FD = nn.Linear(dict_size, activation_dim, bias=True)

    def encode(self, model_activations_D: torch.Tensor) -> torch.Tensor:
        # ReLU zeroes negative pre-activations; the loss encourages most entries to be 0
        return torch.relu(self.encoder_DF(model_activations_D))

    def decode(self, encoded_representation_F: torch.Tensor) -> torch.Tensor:
        return self.decoder_FD(encoded_representation_F)

    def forward_pass(self, model_activations_D: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        encoded_representation_F = self.encode(model_activations_D)
        reconstructed_model_activations_D = self.decode(encoded_representation_F)
        return reconstructed_model_activations_D, encoded_representation_F
```

The loss function of a standard autoencoder is based on the accuracy of the reconstruction of the input. The most straightforward way to introduce sparsity is to add a sparsity penalty to the SAE's loss function. The most common way to calculate this penalty is to take the L1 loss of the SAE's encoded representation (not the SAE weights) and multiply it by an L1 coefficient. This L1 coefficient is a key hyperparameter in SAE training because it determines the tradeoff between achieving sparsity and maintaining reconstruction accuracy.

Note that we are not optimizing for interpretability here. Instead, interpretable SAE features are a side effect of optimizing sparsity and reconstruction. Below is a reference loss function.

```python
import einops
import torch

# B = batch size, D = d_model, F = dictionary_size


def calculate_loss(autoencoder: SparseAutoEncoder, model_activations_BD: torch.Tensor, l1_coefficient: float) -> torch.Tensor:
    reconstructed_model_activations_BD, encoded_representation_BF = autoencoder.forward_pass(model_activations_BD)

    # Squared reconstruction error, summed over D and averaged over the batch
    reconstruction_error_BD = (reconstructed_model_activations_BD - model_activations_BD).pow(2)
    reconstruction_error_B = einops.reduce(reconstruction_error_BD, 'B D -> B', 'sum')
    l2_loss = reconstruction_error_B.mean()

    # The encodings are post-ReLU (non-negative), so their sum equals their L1 norm.
    # Note: this sums over the batch too, so the penalty's scale grows with batch size.
    l1_loss = l1_coefficient * encoded_representation_BF.sum()

    loss = l2_loss + l1_loss
    return loss
```
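Putting the pieces together, a minimal training loop might look like the sketch below. It is self-contained rather than importing the classes above, and everything in it is an illustrative assumption: the synthetic "activations" (sparse mixes of random unit-norm directions), the sizes, the learning rate, and the `l1_coefficient` value. Unlike the reference loss, it averages the L1 term per example so the penalty does not scale with batch size.

```python
import torch
import torch.nn as nn
from torch.nn.functional import normalize

torch.manual_seed(0)
d_model, dict_size, n_true_features = 32, 128, 64
l1_coefficient = 1e-3  # hypothetical value; in practice this is tuned

# Synthetic stand-in for model activations: each sample is a sparse, non-negative
# mix of a few random unit-norm "ground-truth" directions.
true_directions = normalize(torch.randn(n_true_features, d_model), dim=-1)


def sample_activations(batch_size: int) -> torch.Tensor:
    active = (torch.rand(batch_size, n_true_features) < 0.05).float()  # ~3 active per sample
    coefficients = torch.rand(batch_size, n_true_features) * active
    return coefficients @ true_directions  # (batch_size, d_model)


encoder = nn.Linear(d_model, dict_size)
decoder = nn.Linear(dict_size, d_model)
optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-3)

losses = []
for step in range(500):
    x = sample_activations(256)
    encoded = torch.relu(encoder(x))
    reconstructed = decoder(encoded)
    l2_loss = (reconstructed - x).pow(2).sum(dim=-1).mean()
    # encoded >= 0 after ReLU, so its sum is its L1 norm; averaged over the batch here
    l1_loss = l1_coefficient * encoded.sum(dim=-1).mean()
    loss = l2_loss + l1_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

with torch.no_grad():
    l0 = (torch.relu(encoder(sample_activations(1024))) != 0).float().sum(dim=-1).mean().item()
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}, L0: {l0:.1f}")
```

Raising `l1_coefficient` typically lowers L0 at the cost of a higher reconstruction error, which is the tradeoff described above.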



Schematic diagram of the forward pass of a sparse autoencoder.

This is a single forward pass through a sparse autoencoder. First, we have a 1x4 model vector. Then we multiply it by a 4x8 encoder matrix to get a 1x8 encoded vector, and then apply ReLU to turn negative values into zeros. This encoded vector is sparse. After that, we multiply it by an 8x4 decoder matrix to get a 1x4 imperfectly reconstructed model activation.
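The same 1x4 → 1x8 → 1x4 pass, written out with the matrix shapes from the diagram. This is a sketch with random, untrained matrices and no biases (both assumptions), so the "reconstruction" bears no relation to the input; in a trained SAE, the sparsity penalty would also drive most encoded entries to zero, whereas here ReLU only zeroes the negative ones.

```python
import torch

torch.manual_seed(0)
model_vector_1x4 = torch.randn(1, 4)
encoder_4x8 = torch.randn(4, 8)  # untrained placeholder weights
decoder_8x4 = torch.randn(8, 4)  # untrained placeholder weights

pre_activations_1x8 = model_vector_1x4 @ encoder_4x8
encoded_1x8 = torch.relu(pre_activations_1x8)  # negative values become 0
reconstructed_1x4 = encoded_1x8 @ decoder_8x4

print(encoded_1x8)
print(reconstructed_1x4)
```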

Hypothetical SAE Feature Demonstration

Ideally, every non-zero value in the SAE representation corresponds to an understandable component.

Let's take a hypothetical example. Suppose that, in GPT-3, the 12,288-dimensional vector [1.5, 0.2, -1.2, ...] represents "Golden Retriever". The SAE decoder is a matrix of shape (49,152, 12,288), but we can also think of it as a collection of 49,152 vectors, each of shape (1, 12,288). If the 317th vector of the SAE decoder has learned the same "Golden Retriever" concept as GPT-3, then that decoder vector is roughly equal to [1.5, 0.2, -1.2, ...].

Whenever the 317th element of the SAE activation is non-zero, the vector corresponding to "Golden Retriever", scaled by the magnitude of the 317th element, is added to the reconstructed activation. In mechanistic interpretability terms, this can be described concisely as "the decoder vector corresponds to a linear representation of a feature in residual stream space".
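This "sum of decoder vectors" view can be checked directly. In the sketch below (tiny hypothetical sizes, an untrained `nn.Linear` decoder, and a made-up code in which only features 3 and 17 are active), the decoder output equals the decoder bias plus each active feature's decoder vector scaled by its activation. Note that `nn.Linear` stores its weight as (D, F), so a feature's decoder vector is a column of `decoder.weight`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
activation_dim, dict_size = 12, 48  # tiny stand-ins for 12,288 and 49,152
decoder = nn.Linear(dict_size, activation_dim)  # untrained placeholder decoder

# A hypothetical sparse SAE code in which only features 3 and 17 are active.
encoded = torch.zeros(dict_size)
encoded[3] = 2.0
encoded[17] = 0.5

with torch.no_grad():
    full = decoder(encoded)
    # nn.Linear stores weight as (out_features, in_features) = (D, F),
    # so feature i's decoder vector is column i of decoder.weight.
    manual = decoder.bias + encoded[3] * decoder.weight[:, 3] + encoded[17] * decoder.weight[:, 17]

print(torch.allclose(full, manual, atol=1e-6))  # True
```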

In other words, an SAE with a 49,152-dimensional encoded representation has 49,152 features. Each feature consists of a corresponding encoder vector and decoder vector. The encoder vector detects the model's internal concept while minimizing interference from other concepts, and the decoder vector represents the "real" feature direction. Experiments found that a feature's encoder and decoder vectors differ, with a median cosine similarity of 0.5. In the figure below, the three red boxes correspond to a single feature.



Schematic of a sparse autoencoder, where the three red boxes correspond to SAE feature 1 and the green box corresponds to feature 4. Each feature has a 1x4 encoder vector, 1x1 feature activation, and 1x4 decoder vector. The reconstructed activations are constructed using only the decoder vectors from SAE features 1 and 4. If the red box represents "red color" and the green box represents "ball", then the model is likely to represent "red ball".
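Comparing a feature's encoder and decoder vectors, as in the cosine-similarity measurement mentioned above, could be done along these lines. The helper name `encoder_decoder_cosine_similarity` is hypothetical, and the randomly initialized modules are placeholders, so the median printed here should be close to zero rather than the 0.5 observed for trained SAEs.

```python
import torch
import torch.nn as nn
from torch.nn.functional import cosine_similarity


def encoder_decoder_cosine_similarity(encoder: nn.Linear, decoder: nn.Linear) -> torch.Tensor:
    """Per-feature cosine similarity between encoder and decoder vectors."""
    encoder_vectors_FD = encoder.weight    # (F, D): row i detects feature i
    decoder_vectors_FD = decoder.weight.T  # (F, D): row i is feature i's output direction
    return cosine_similarity(encoder_vectors_FD, decoder_vectors_FD, dim=-1)


torch.manual_seed(0)
encoder = nn.Linear(64, 256)  # untrained placeholder, D = 64, F = 256
decoder = nn.Linear(256, 64)
with torch.no_grad():
    similarities_F = encoder_decoder_cosine_similarity(encoder, decoder)
print(similarities_F.shape, similarities_F.median().item())
```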

So how do we know what hypothetical feature 317 represents? The current practice is to look at the inputs that maximally activate the feature and make an intuitive judgment about their interpretability. The inputs that activate each feature are often interpretable.

For example, Anthropic trained an SAE on Claude Sonnet and found that different SAE features were activated by text and images related to the Golden Gate Bridge, neuroscience, and popular tourist attractions. Other features were activated by less obvious concepts, such as a feature of an SAE trained on Pythia that was activated by the concept of "the final token of a relative clause or prepositional phrase that modifies the subject of a sentence."
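Finding the maximally activating inputs amounts to a top-k search over a feature's activations. A minimal sketch, assuming feature activations have already been computed for a set of inputs; the helper name, the example texts, and the activation values are all hypothetical. Real pipelines usually work per token and also record surrounding context.

```python
import torch


def top_activating_examples(feature_acts_NF: torch.Tensor, texts: list[str],
                            feature_idx: int, k: int = 3) -> list[tuple[str, float]]:
    """Return the k inputs with the largest activation for one SAE feature."""
    values, indices = torch.topk(feature_acts_NF[:, feature_idx], k)  # sorted, largest first
    return [(texts[i], round(v, 2)) for v, i in zip(values.tolist(), indices.tolist())]


texts = ["golden retriever puppy", "stock market news", "my dog fetched the ball",
         "tax forms", "labrador vs retriever"]
# Hypothetical activations of 2 SAE features on the 5 inputs above, shape (N, F).
feature_acts_NF = torch.tensor([[0.0, 3.1], [0.7, 0.0], [0.0, 1.8], [1.2, 0.0], [0.0, 2.5]])

result = top_activating_examples(feature_acts_NF, texts, feature_idx=1)
print(result)
# [('golden retriever puppy', 3.1), ('labrador vs retriever', 2.5), ('my dog fetched the ball', 1.8)]
```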

Since the SAE decoder vector has the same shape as the LLM's intermediate activations, a causal intervention can be performed simply by adding the decoder vector to the model activations. The strength of the intervention can be adjusted by multiplying the decoder vector by a scaling factor. When Anthropic researchers added the "Golden Gate Bridge" SAE decoder vector to Claude's activations, Claude was compelled to mention the Golden Gate Bridge in every response.

Below is a reference implementation of a causal intervention using hypothetical feature 317. Like “Golden Gate Bridge” Claude, this very simple intervention would compel the GPT-3 model to mention “Golden Retriever” in every response.

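```python
import torch


def perform_intervention(model_activations_D: torch.Tensor, decoder_FD: torch.Tensor, scale: float) -> torch.Tensor:
    # decoder_FD is the decoder matrix of shape (F, D); row 317 is the decoder
    # vector of hypothetical feature 317 ("Golden Retriever").
    # (With the nn.Linear decoder above, whose weight is (D, F), pass decoder_FD.weight.T.)
    intervention_vector_D = decoder_FD[317, :]
    scaled_intervention_vector_D = intervention_vector_D * scale
    modified_model_activations_D = model_activations_D + scaled_intervention_vector_D
    return modified_model_activations_D
```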
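In practice, an intervention like this is applied during the forward pass. One common way to do that in PyTorch is a forward hook that returns the modified output. The sketch below is an assumption-laden illustration: the toy model, the random placeholder decoder, `feature_idx = 17`, and `scale = 8.0` are all hypothetical, and the check confirms that the hooked model behaves as if the scaled decoder vector had been added to the hooked layer's output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, dict_size = 16, 64
feature_idx, scale = 17, 8.0  # hypothetical feature index and steering strength

# Toy stand-in for an LLM; we steer the activations produced by the ReLU (index 1).
toy_model = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model))
decoder_FD = torch.randn(dict_size, d_model)  # placeholder for a trained SAE decoder, shape (F, D)


def steering_hook(module, inputs, output):
    # A forward hook that returns a value replaces the module's output.
    return output + scale * decoder_FD[feature_idx]


x = torch.randn(4, d_model)
with torch.no_grad():
    hidden = toy_model[1](toy_model[0](x))  # unmodified intermediate activations
    expected = toy_model[2](hidden + scale * decoder_FD[feature_idx])
    handle = toy_model[1].register_forward_hook(steering_hook)
    steered = toy_model(x)
    handle.remove()  # detach the hook so later forward passes are unaffected

print(torch.allclose(steered, expected, atol=1e-5))  # True
```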

Evaluation Challenges of Sparse Autoencoders

One of the main challenges in using SAEs is evaluation. We can train sparse autoencoders to interpret language models, but we have no measurable ground-truth representation of natural language to compare against. Currently, evaluation is subjective: basically, “we look at the inputs that activate a bunch of features and then make an intuitive judgment about how interpretable those features are.” This is a major limitation in the field of interpretability.

Researchers have found some common proxy metrics that seem to correspond to feature interpretability. The most commonly used are L0 and Loss Recovered. L0 is the average number of non-zero elements in the SAE's encoded intermediate representation. Loss Recovered replaces the GPT's original activations with the reconstructed activations and measures the additional loss caused by the imperfect reconstruction. These two metrics usually trade off against each other: an SAE can improve sparsity at the cost of reconstruction accuracy.
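Both metrics are simple to compute once the relevant quantities are available. The sketch below assumes one common formulation of Loss Recovered that normalizes against a zero-ablation baseline (replacing the activations with zeros), so 1.0 means the SAE reconstruction is as good as the original activations and 0.0 means it is no better than zeroing them. The loss values are made-up placeholders, not measurements.

```python
import torch


def l0_norm(encoded_BF: torch.Tensor) -> float:
    """Average number of non-zero SAE features per input."""
    return (encoded_BF != 0).float().sum(dim=-1).mean().item()


def fraction_of_loss_recovered(clean_loss: float, sae_loss: float, zero_ablation_loss: float) -> float:
    """Fraction of the gap between zero-ablation loss and clean loss closed by the SAE."""
    return (zero_ablation_loss - sae_loss) / (zero_ablation_loss - clean_loss)


encoded_BF = torch.tensor([[0.0, 1.2, 0.0, 0.3], [0.5, 0.0, 0.0, 0.0]])
print(l0_norm(encoded_BF))  # 1.5
print(fraction_of_loss_recovered(clean_loss=2.0, sae_loss=2.5, zero_ablation_loss=7.0))  # 0.9
```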

When comparing SAEs, a common approach is to plot the two variables and then examine the tradeoff between them. To achieve a better tradeoff, many new SAE methods, such as DeepMind's Gated SAE and OpenAI's TopK SAE, modify the sparsity penalty. The following figure is from DeepMind's Gated SAE paper. Gated SAE is represented by the red line, located in the upper left of the figure, which shows that it performs better on this tradeoff.



Gated SAE L0 and Loss Recovered
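As an illustration of modifying the sparsity mechanism, OpenAI's TopK SAE enforces sparsity directly by keeping only the k largest latents instead of relying on an L1 penalty. The function below is a minimal sketch of that activation, not OpenAI's code: it keeps the k largest pre-activations per row, applies ReLU to them, and zeroes everything else. The sizes and `k = 8` are arbitrary.

```python
import torch


def topk_activation(pre_activations_BF: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations per row (after ReLU), zero the rest."""
    values, indices = torch.topk(pre_activations_BF, k, dim=-1)
    sparse_BF = torch.zeros_like(pre_activations_BF)
    sparse_BF.scatter_(-1, indices, torch.relu(values))
    return sparse_BF


torch.manual_seed(0)
pre_activations_BF = torch.randn(4, 64)  # stand-in for encoder outputs
encoded_BF = topk_activation(pre_activations_BF, k=8)
print((encoded_BF != 0).sum(dim=-1))  # at most 8 non-zero features per row
```

Because L0 is bounded by k by construction, k directly sets the sparsity level and no L1 coefficient needs tuning.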

SAE evaluation involves several layers of indirection. L0 and Loss Recovered are proxy metrics, but we do not use them during training, because L0 is not differentiable and Loss Recovered is computationally expensive to calculate during SAE training. Instead, our training loss combines an L1 penalty term with the accuracy of reconstructing the internal activations, rather than their effect on the downstream loss.

The training loss does not directly correspond to the proxy metrics, and the proxy metrics are themselves only proxies for the subjective assessment of feature interpretability. Since our real goal is to "understand how the model works", and subjective interpretability assessment is in turn only a proxy for that, there is yet another layer of mismatch. Some important concepts in LLMs may not be easily interpretable, and we may overlook them if we blindly optimize for interpretability.

Summary

There is still a long way to go in the field of interpretability, but SAEs are a real step forward. SAEs enable interesting new applications, such as an unsupervised method for finding steering vectors like the "Golden Gate Bridge" steering vector. SAEs also make it easier to find circuits in language models, which could potentially be used to remove unwanted biases inside the model.

The fact that SAEs can find interpretable features (even though their objective is only to identify patterns in activations) suggests that they reveal something meaningful. It is also evidence that LLMs are actually learning something meaningful, rather than merely memorizing surface-level statistics.

SAEs also represent an early milestone toward what companies like Anthropic have been aiming for: an “MRI for machine learning models.” SAEs do not yet provide perfect understanding, but they may already be useful for detecting bad behavior. The main challenges of SAEs and SAE evaluation are not insurmountable, and many researchers are working on them.

For a further introduction to sparse autoencoders, see Callum McDougall’s Colab notebook: https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab

https://www.reddit.com/r/MachineLearning/comments/1eeihdl/d_an_intuitive_explanation_of_sparse_autoencoders/

https://adamkarvonen.github.io/machine_learning/2024/06/11/sae-intuitions.html