
Sparse Autoencoders (SAEs) for Discovery

Introduction

Trait discovery is a challenging problem in many domains, as it requires identifying meaningful and interpretable features across a population. Large foundation models have demonstrated the ability to learn effective representations of data, capturing intricate patterns and relationships. In this tutorial, we explore how sparse autoencoders (SAEs) can extract candidate traits from these representations, giving domain scientists a starting point that can accelerate insights and discoveries.

Learning Objectives

By the end of this tutorial, you will be able to:

  1. Train sparse autoencoders (SAEs) on a dataset.
  2. Choose an optimal SAE for trait extraction/visualization.
  3. Visualize SAE features on images.
  4. Understand how to identify features that are possibly meaningful, and the limitations of this approach.

Prerequisites

  • Python version: >= 3.12
  • Packages: managed using uv
  • Data: iNaturalist 2021 mini, predownloaded to CyVerse.
  • Prior knowledge: Comfort using the Linux terminal. No machine learning experience is necessary, though basic knowledge will help with understanding.

Setup

In the CyVerse Discovery Environment, launch "Jupyter Lab PyTorch GPU".

Wait for "Launching VICE app: jupyter-lab-pytorch-gpu" to complete.

Click "Terminal" to start a new terminal session. This should place you in ~/data-store.

Clone the saev-demo repository:

git clone https://github.com/Imageomics/saev-demo.git

Enter the repository directory:

cd saev-demo

Install uv for fast environment setup:

curl -LsSf https://astral.sh/uv/install.sh | sh

Reload your shell environment:

source $HOME/.local/bin/env

Create a virtual environment and install the dependencies (uv downloads a suitable Python version if one is not already available):

uv sync

Background

What is an embedding?

An embedding is a numeric summary of an image (or part of an image).

Instead of storing raw pixels, a model converts each image into a vector (a list of numbers) or set of vectors that captures useful visual patterns. Images or parts of images that look similar often have similar embeddings.

You can think of embeddings as a compact "feature fingerprint" for each image or each part of an image. In this tutorial, SAEs are trained on these fingerprints.
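To make "similar images have similar embeddings" concrete, here is a minimal sketch that compares embeddings with cosine similarity, a common similarity measure for embedding vectors. The three 4-number vectors are made-up toy values (real DINOv2 ViT-S embeddings have 384 numbers), and this comparison is only an illustration, not a step the saev-demo tooling necessarily performs.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity: 1.0 means the vectors point the same way, 0.0 means unrelated."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy 4-dimensional "embeddings"; the values are invented for illustration.
fish_a = np.array([0.9, 0.1, 0.4, 0.0])
fish_b = np.array([0.8, 0.2, 0.5, 0.1])
bird = np.array([0.0, 0.9, 0.1, 0.7])

print(round(cosine_similarity(fish_a, fish_b), 3))  # high: similar "fingerprints"
print(round(cosine_similarity(fish_a, bird), 3))    # low: different "fingerprints"
```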

Vision Transformers (ViTs)

Vision Transformers are image models that process an image as many small patches.

Figure: Vision Transformer patch-based representation.

For each patch, the model builds internal features across many layers. We use one of these internal layers as input to the SAE, because these features already contain rich biological structure (shape, texture, color, etc.).

In this tutorial, we use DINOv2 ViT-S embeddings generated by the saev-demo tooling.
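The patch idea can be sketched in a few lines of NumPy. This assumes a 224 × 224 RGB input (the resolution used by saev-demo is not stated here) and the 14 × 14 patch size implied by the checkpoint name dinov2_vits14_reg. The function only cuts the image into patches; the model itself then turns each patch into a 384-number activation vector at every layer.

```python
import numpy as np

PATCH_SIZE = 14  # the "14" in dinov2_vits14_reg
IMG_SIZE = 224   # assumed input resolution; the tooling may resize differently

def patchify(image: np.ndarray, patch: int = PATCH_SIZE) -> np.ndarray:
    """Split an (H, W, C) image into (n_patches, patch * patch * C) flattened patches, row by row."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by the patch size"
    grid_h, grid_w = h // patch, w // patch
    patches = image.reshape(grid_h, patch, grid_w, patch, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (grid_h, grid_w, patch, patch, c)
    return patches.reshape(grid_h * grid_w, patch * patch * c)

image = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
patches = patchify(image)
print(patches.shape)  # (256, 588): a 16 x 16 grid of 14 x 14 x 3 patches
```

Under these assumptions, one image yields 256 patches, so one layer of ViT-S activations for that image is a 256 × 384 array.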

Sparse Autoencoders

A sparse autoencoder (SAE) is a model that learns to decompose embeddings into a set of sparse features (meaning that, for any single example, most feature values are 0). An SAE is trained with two objectives in mind:

  • Reconstruct the original embedding well.
  • Use only a small number of active features per embedding (sparsity).

In practice, the sparse features learned by an SAE tend to be more interpretable than the original embeddings. For example, a feature may strongly activate on a repeated visual trait (like fin shape or wing pattern). These features can then be treated as trait candidates that still need interpretation by a domain scientist.
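The two objectives can be written down directly. Below is a minimal, from-scratch NumPy sketch of a TopK-style SAE forward pass, not saev's implementation: a ReLU encoder maps each 384-number embedding to 6,144 features, all but the k largest are set to zero, and a linear decoder rebuilds the embedding. The weights are random here (training would fit them), and k = 32 is an arbitrary illustrative value.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_sae = 384, 6144
top_k = 32  # illustrative number of active features per embedding

# Random weights; training would adjust them to minimize the reconstruction error below.
W_enc = rng.normal(0, 0.02, size=(d_model, d_sae))
b_enc = np.zeros(d_sae)
W_dec = rng.normal(0, 0.02, size=(d_sae, d_model))
b_dec = np.zeros(d_model)

def encode(x: np.ndarray, k: int = top_k) -> np.ndarray:
    """Map embeddings (batch, d_model) to sparse features (batch, d_sae), keeping the k largest per row."""
    pre = np.maximum((x - b_dec) @ W_enc + b_enc, 0.0)  # ReLU
    drop = np.argsort(pre, axis=1)[:, :-k]              # indices of all but the k largest values
    np.put_along_axis(pre, drop, 0.0, axis=1)           # zero them in place
    return pre

def decode(f: np.ndarray) -> np.ndarray:
    """Reconstruct embeddings as a weighted sum of decoder directions."""
    return f @ W_dec + b_dec

x = rng.normal(size=(8, d_model))              # 8 stand-in embeddings
f = encode(x)
x_hat = decode(f)
mse = np.mean((x - x_hat) ** 2)                # objective 1: reconstruction error
l0 = np.count_nonzero(f, axis=1).mean()        # objective 2: active features per embedding
print(f"MSE={mse:.3f}, active features per embedding={l0:.0f} of {d_sae}")
```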

Dataset: Fish-Vista

Fish-Vista is an Imageomics fish dataset with species and trait labels. It is useful for testing whether discovered features align with known biological traits, as segmented morphological traits are provided for a large portion of the images. We provide a sparse autoencoder trained on Fish-Vista, and use a small subset of Fish-Vista for visualization.

Dataset: iNaturalist 2021 Mini Validation Set

iNaturalist 2021 Mini is a subset of iNaturalist images across many taxa. These images are typically taken outside of a lab setting, so they may not be as clean as images of laboratory or museum specimens. For this project, we train an SAE on iNaturalist 2021 Mini's training set.

Steps

Here, we walk through the process of training an SAE and then visualize the features of fully trained SAEs. While you will not train an SAE during this workshop (due to time constraints), you will see every command you would need to train your own. You will also be able to visualize fully trained SAE features on a small subset of the Fish-Vista dataset.

Generating shards

First, generate embedding shards from your image folder. Shards are saved blocks of model activations that make SAE training efficient.

Run from the repository root:

mkdir -p saev/shards
uv run launch.py shards --shards-root saev/shards/ --family dinov2 --ckpt dinov2_vits14_reg --d-model 384 --n-workers 0 data:fake-img

Note:

  • When using this on real data, replace data:fake-img with data:img-folder. Then, provide the path to your dataset with the argument --data.root <dataset directory>.
  • The command prints the shard directory that was created. Record this, as you will use that path in training.
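One reason saved activations help is that the ViT does not have to be rerun on every image at every training step, but the saved activations take real disk space. A rough estimate, assuming float32 values (4 bytes each), 256 patch tokens per image (a 224 × 224 input split into 14 × 14 patches), and a hypothetical cap of 2,400 images per shard; the actual saev shard layout may store additional tokens or use a different file size:

```python
import math

def activation_storage(n_images: int, tokens_per_image: int, d_model: int,
                       bytes_per_value: int = 4) -> int:
    """Bytes needed to store one layer of float32 activations for every token of every image."""
    return n_images * tokens_per_image * d_model * bytes_per_value

def n_shards(n_images: int, images_per_shard: int) -> int:
    """Number of shard files if each file holds at most `images_per_shard` images (hypothetical cap)."""
    return math.ceil(n_images / images_per_shard)

size = activation_storage(n_images=10_000, tokens_per_image=256, d_model=384)
print(f"{size / 1e9:.2f} GB")                     # 3.93 GB
print(n_shards(10_000, images_per_shard=2_400))   # 5
```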

Training SAEs

Using the following command, we train an SAE on the model activations generated above. Note that we set the SAE dimension (--sae.d-sae 6144) to 16 times the original model's activation dimension (--sae.d-model 384). This is common practice when training SAEs, but changing this ratio can produce better or worse results, depending on your model and dataset.

uv run launch.py train \
  --n-train 10000 \
  --n-val 10 \
  --runs-root ./saev/runs/ \
  --train-data.shards ./saev/shards/<insert shard hash>/ \
  --train-data.layer -2 \
  --val-data.shards ./saev/shards/<insert shard hash>/ \
  --val-data.layer -2 \
  --sae.d-model 384 \
  --sae.d-sae 6144 \
  sae.activation:batch-top-k \
  sae.activation.sparsity:no-sparsity

Notes:

  • Use the shard path from the previous step.
  • Training writes a run folder under ./saev/runs (the --runs-root directory). The provided training command will not produce any output here, but we have provided two trained SAEs for visualization. In practice, each run's data is written automatically to its own folder inside this directory.
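The sae.activation:batch-top-k option selects a batch top-k activation. Here is a minimal sketch of the general BatchTopK idea (our own NumPy version, not saev's code): instead of keeping exactly k features per embedding, keep the k × batch_size largest activations across the whole batch, so individual embeddings can use more or fewer features while the batch averages k. We read sparsity:no-sparsity as adding no extra sparsity penalty on top of this, which is an assumption. The sketch also shows the 16× expansion from the command and a parameter count that assumes separate encoder and decoder weight matrices plus one bias vector each.

```python
import numpy as np

def batch_top_k(pre: np.ndarray, k: int) -> np.ndarray:
    """Keep the k * batch_size largest (post-ReLU) activations across the batch; zero the rest."""
    assert k >= 1
    acts = np.maximum(pre, 0.0)                      # ReLU, so negative values never survive
    n_keep = min(k * acts.shape[0], acts.size)
    flat = acts.ravel()
    keep = np.argpartition(flat, -n_keep)[-n_keep:]  # indices of the n_keep largest values
    out = np.zeros_like(flat)
    out[keep] = flat[keep]
    return out.reshape(acts.shape)

# Two embeddings, four features, k = 2: four activations survive in total.
pre = np.array([[5.0, 0.1, 0.2, 0.0],
                [4.0, 3.0, 2.0, 0.3]])
out = batch_top_k(pre, k=2)
print(np.count_nonzero(out, axis=1))  # [1 3]: uneven per row, average of 2

d_model = 384
d_sae = 16 * d_model                            # 6144, the --sae.d-sae value in the command
n_params = 2 * d_model * d_sae + d_sae + d_model  # assumes untied weights, encoder + decoder biases
print(d_sae, n_params)
```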

Pareto analysis (optional)

After training, we may have several SAEs trained with different hyperparameters. To select the best SAE, we perform a Pareto analysis. We want low reconstruction error (MSE) and high sparsity (a small k, the number of active features), as SAEs that can reconstruct model activations from a few sparse features typically give better features. These two goals trade off against each other, so instead of a single best model we look for the Pareto frontier: the runs for which no other run achieves both a lower MSE and a smaller k. These runs can be found using the following command:

(Note that you would need a free account with Weights & Biases and an API key to run this.)

uv run python demo/pareto_runs.py --runs-root ./saev/runs --out-png demo/pareto_k_vs_mse.png --out-csv demo/pareto_k_vs_mse.csv --annotate

This produces:

  • A plot (.png) showing all runs and the frontier.
  • A table (.csv) with metrics for each run.

Pick one run on (or near) the frontier for inference. In this demo we will not do Pareto analysis, but this step is encouraged in practice, especially when trying multiple sets of hyperparameters.
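Finding the frontier itself is simple. This sketch uses made-up runs and assumes each run is summarized by its k and validation MSE; it illustrates the dominance rule only and is not necessarily how demo/pareto_runs.py is implemented. A run is kept unless another run is at least as good on both axes and strictly better on one.

```python
def pareto_frontier(runs: list[dict]) -> list[dict]:
    """Return the runs not dominated by any other run (smaller k and lower MSE are both better)."""
    frontier = []
    for run in runs:
        dominated = any(
            other["k"] <= run["k"] and other["mse"] <= run["mse"]
            and (other["k"] < run["k"] or other["mse"] < run["mse"])
            for other in runs
        )
        if not dominated:
            frontier.append(run)
    return sorted(frontier, key=lambda r: r["k"])

# Hypothetical runs; the numbers are invented for illustration.
runs = [
    {"name": "run-a", "k": 16, "mse": 0.30},
    {"name": "run-b", "k": 32, "mse": 0.22},
    {"name": "run-c", "k": 32, "mse": 0.25},  # dominated by run-b
    {"name": "run-d", "k": 64, "mse": 0.15},
    {"name": "run-e", "k": 64, "mse": 0.24},  # dominated by run-b and run-d
]
print([r["name"] for r in pareto_frontier(runs)])  # ['run-a', 'run-b', 'run-d']
```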

Generating Sparse Features

Typically, you would now run inference with your selected SAE. This computes sparse feature activations for the dataset and is usually required before visualization. Here, our dataset is small enough that the sparse activations are computed on the fly, so we skip this step.

uv run launch.py inference --run ./saev/runs/<chosen run path here>/ --data.shards ./saev/shards/<shard path here>/
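Once sparse activations exist, a typical next step is to rank images by how strongly each feature fires on them. A minimal sketch, assuming a dense toy array of shape (n_images, n_patches, d_sae) and scoring each image by its maximum patch activation (one common choice; the actual inference output format and ranking in saev may differ). With real data, activations would be streamed from shards rather than held in one array.

```python
import numpy as np

def top_images_per_feature(acts: np.ndarray, n_top: int = 3) -> np.ndarray:
    """
    acts: (n_images, n_patches, d_sae) sparse feature activations.
    Returns (d_sae, n_top) image indices, strongest first, ranked by each
    image's maximum activation over its patches.
    """
    per_image = acts.max(axis=1)                             # (n_images, d_sae)
    order = np.argsort(-per_image, axis=0, kind="stable")    # descending along images
    return order[:n_top].T                                   # (d_sae, n_top)

# Toy activations: 4 images, 2 patches each, 3 features.
acts = np.zeros((4, 2, 3))
acts[2, 1, 0] = 5.0  # feature 0 fires strongly on image 2
acts[0, 0, 0] = 1.0  # ...and weakly on image 0
acts[3, 0, 1] = 2.0  # feature 1 fires on image 3
top = top_images_per_feature(acts, n_top=2)
print(top[0])  # [2 0]
```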

Visualization

Finally, we can visualize the SAE features on our input dataset. To start the Streamlit visualization app:

uv run streamlit run demo/feature_vis.py \
  --server.headless true \
  --server.enableCORS false \
  --server.enableXsrfProtection false

To open the app, you need the proxy address of your CyVerse session.

In your browser's URL bar, you will see something like "https://a12345abc.cyverse.run/lab". Note the prefix "a12345abc".

In another browser tab, paste "https://placeholder.cyverse.run/proxy/8501/" (as-is, it will return 404 Not Found).

Replace the proxy URL prefix "placeholder" with that of your Jupyter Lab session (i.e. "https://a12345abc.cyverse.run/proxy/8501/" in the example). The app should load at this address after a moment.

In the app, the fields on the left let you visualize SAE features for different SAE models on the Fish-Vista subset. First, keep the default settings and press the "Generate visualization" button. This computes the SAE features for each image patch and overlays them on the input images. Bright yellow means that a feature has a high value for that patch, while faint purple or no highlight means that the feature activates weakly or not at all there.
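The overlay can be sketched as follows: one feature's 256 patch values are laid out on the 16 × 16 patch grid, each value is expanded to cover its 14 × 14 pixel block, and the map is scaled to [0, 1] so that a colormap running from purple (low) to yellow (high), such as matplotlib's viridis, can be drawn over the image. This assumes a 224 × 224 input and per-image max normalization; the app's own normalization and colormap may differ.

```python
import numpy as np

def patch_heatmap(patch_acts: np.ndarray, grid: int = 16, patch: int = 14) -> np.ndarray:
    """Turn one feature's per-patch activations (grid * grid,) into a pixel-level map in [0, 1]."""
    grid_acts = patch_acts.reshape(grid, grid)            # row-major patch order
    pixels = np.kron(grid_acts, np.ones((patch, patch)))  # each patch value -> patch x patch block
    peak = pixels.max()
    return pixels / peak if peak > 0 else pixels

acts = np.zeros(256)
acts[17] = 3.0  # hypothetical: the feature fires only on patch (row 1, column 1)
heat = patch_heatmap(acts)
print(heat.shape)  # (224, 224)
```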

Interpretation tips:

  • Look for consistent biological patterns across top images for a feature.
  • Check for dataset artifacts (backgrounds, lighting, labels) that may drive activations.
  • Treat discovered features as hypotheses that need validation.
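One way to make the artifact check less subjective is to measure how much of a feature's activation falls inside a region of interest. Because Fish-Vista provides segmented traits for many images, such a mask could be a trait, the whole fish, or the background. The helper below is a hypothetical heuristic, not part of saev-demo: a high share on the background hints at an artifact, and a high share on a trait mask only supports, and does not confirm, a trait hypothesis.

```python
import numpy as np

def fraction_on_mask(heat: np.ndarray, mask: np.ndarray) -> float:
    """Share of a feature's total activation that falls inside a boolean mask of the same shape."""
    total = heat.sum()
    return float(heat[mask].sum() / total) if total > 0 else 0.0

# Toy 4 x 4 example: the "fish" occupies the left half, background the right half.
heat = np.array([[0.9, 0.8, 0.0, 0.1],
                 [0.7, 0.6, 0.0, 0.0],
                 [0.0, 0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.2, 0.0]])
fish_mask = np.zeros((4, 4), dtype=bool)
fish_mask[:, :2] = True
print(round(fraction_on_mask(heat, fish_mask), 2))  # 0.91
```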

To see how the training dataset influences SAE features, change the run directory field from ./saev/runs/inat to ./saev/runs/fish and press the "Generate visualization" button again. You should see that the features from the Fish-Vista SAE correspond more closely to biologically relevant fish traits than those from the iNaturalist SAE.

Summary

In this tutorial, you walked through generating ViT embedding shards, training an SAE, selecting a run with Pareto analysis, and producing sparse feature activations, and you visualized feature behavior on images.

SAEs can help surface candidate traits and structure in large image collections. The outputs are useful for discovery, but interpretation should be done carefully and ideally cross-checked with domain expertise and metadata.