Lecture 01: Clustering class demo

Let’s cluster images!!

For this demo, I’m going to use two image datasets:

  1. A small subset of the 200 Bird Species with 11,788 Images dataset (available here)

  2. A tiny subset of Food-101 (available here)

To run the code below, you need to install pytorch and torchvision in the course conda environment:

conda install pytorch torchvision -c pytorch

import os
import random
import sys
import time

import numpy as np
import pandas as pd

sys.path.append(os.path.join(os.path.abspath(".."), "code"))
from plotting_functions import *

DATA_DIR = os.path.join(os.path.abspath(".."), "data/")

import torch
import torchvision
from torchvision import datasets, models, transforms, utils
from PIL import Image
import matplotlib.pyplot as plt

Let’s start with a small subset of the birds dataset. You can experiment with a bigger dataset if you like.

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device
device(type='mps')
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(seed=42)
import glob
IMAGE_SIZE = 224
def read_img_dataset(data_dir):
    """Read all images under data_dir and return one (shuffled) batch of inputs and labels."""
    data_transforms = transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale pixel values to [-1, 1]
        ])

    image_dataset = datasets.ImageFolder(root=data_dir, transform=data_transforms)
    dataloader = torch.utils.data.DataLoader(
        image_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0  # BATCH_SIZE is set globally below
    )
    inputs, classes = next(iter(dataloader))  # grab one batch (here, the whole dataset)
    return inputs, classes
def plot_sample_imgs(inputs):
    plt.figure(figsize=(10, 70))
    plt.axis("off")
    plt.title("Sample Training Images")
    plt.imshow(np.transpose(utils.make_grid(inputs, padding=1, normalize=True), (1, 2, 0)))
data_dir = os.path.join(DATA_DIR, "birds")
file_names = [image_file for image_file in glob.glob(data_dir + "/*/*.jpg")]
n_images = len(file_names)
BATCH_SIZE = n_images  # because our dataset is quite small
birds_inputs, birds_classes = read_img_dataset(data_dir)
X_birds = birds_inputs.numpy()
plot_sample_imgs(birds_inputs[0:24,:,:,:])
plt.show()
[figure: grid of sample bird images]

For clustering we need to calculate distances between points, so we need a vector representation for each data point. The simplest way to create a vector representation of an image is to flatten it into one long vector of pixel values.

flatten_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    transforms.Lambda(torch.flatten),  # flatten each image into a 1D vector
])
flatten_dataset = datasets.ImageFolder(root=data_dir, transform=flatten_transforms)
flatten_dataloader = torch.utils.data.DataLoader(
    flatten_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
flatten_train, y_train = next(iter(flatten_dataloader))
flatten_images = flatten_train.numpy()
image_shape = [3, 224, 224]
img = flatten_images[20].reshape(image_shape)  # undo the flattening
plt.imshow(np.transpose(img / 2 + 0.5, (1, 2, 0)));  # undo the normalization for display
[figure: one of the bird images, reconstructed from its flattened vector]
flatten_images.shape  # 176 images, each flattened to 3 * 224 * 224 = 150,528 values
(176, 150528)
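Distances between these vectors are what clustering operates on. As a quick sanity check (a minimal sketch, not part of the original demo), we can compute pairwise Euclidean distances between a few flattened images. Note that in 150,528 dimensions, raw pixel distances are dominated by background and lighting rather than the bird itself:

from sklearn.metrics import euclidean_distances

# Pairwise Euclidean distances between the first three flattened images
euclidean_distances(flatten_images[:3])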
from sklearn.cluster import KMeans
k = 3
km_flatten = KMeans(n_clusters=k, n_init='auto', random_state=123)
km_flatten.fit(flatten_images)
KMeans(n_clusters=3, random_state=123)
km_flatten.cluster_centers_.shape
(3, 150528)
flatten_images.shape
(176, 150528)
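One nice property of the flattened representation is that the cluster centers live in the same pixel space as the images, so we can reshape them back to 3 × 224 × 224 and view them as (blurry, averaged) images. A minimal sketch:

# Reshape each cluster center back into an image and display it
fig, axes = plt.subplots(1, k, figsize=(12, 4))
for cluster, ax in zip(range(k), axes):
    center_img = km_flatten.cluster_centers_[cluster].reshape(image_shape)
    ax.imshow(np.transpose(center_img / 2 + 0.5, (1, 2, 0)))
    ax.set_title(f"Cluster {cluster} center")
    ax.axis("off")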
unflatten_inputs = np.array([img.reshape(image_shape) for img in flatten_images])
for cluster in range(k):
    # user-defined functions defined in ../code/plotting_functions.py
    get_cluster_images(km_flatten, flatten_images, unflatten_inputs, cluster, n_img=5)
Image indices:  [158  65  48 125  95]
[figure: sample images from cluster 0]
Image indices:  [165  94  77 152 108]
[figure: sample images from cluster 1]
Image indices:  [156  89 100  25 133]
[figure: sample images from cluster 2]

Let’s try clustering with GMMs

from sklearn.mixture import GaussianMixture

gmm_flatten = GaussianMixture(n_components=k, covariance_type='diag', random_state=123)  # diagonal covariance: a full covariance over 150,528 features would be infeasible
gmm_flatten.fit(flatten_images)
GaussianMixture(covariance_type='diag', n_components=3, random_state=123)
for cluster in range(k):
    # user-defined functions defined in ../code/plotting_functions.py
    get_cluster_images(gmm_flatten, flatten_images, unflatten_inputs, cluster=cluster, n_img=5)
Image indices:  [ 48 106 104 122  87]
[figure: sample images from cluster 0]
Image indices:  [ 56  55  54 113 175]
[figure: sample images from cluster 1]
Image indices:  [114  39 126  90  64]
[figure: sample images from cluster 2]

We still see some mis-categorizations; with flattened raw-pixel vectors, clustering doesn’t work that well.

Let’s try a different input representation: transfer learning, using a pre-trained vision model as a feature extractor. We’ll pass each image in our dataset through a pretrained network and take the representation from the last layer before the classification layer as our feature vector.

def get_features(model, inputs):
    """Extract the model's output representation for each input image."""
    model.eval()
    with torch.no_grad():  # no gradients needed; we only do a forward pass
        Z = model(inputs).detach().numpy()
    return Z
densenet = models.densenet121(weights="DenseNet121_Weights.IMAGENET1K_V1")
densenet.classifier = torch.nn.Identity()  # remove that last "classification" layer
Z_birds = get_features(densenet, birds_inputs)
Z_birds.shape
(176, 1024)
pd.DataFrame(Z_birds)
0 1 2 3 4 5 6 7 8 9 ... 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023
0 0.000221 0.005660 0.002462 0.004169 0.097525 0.287082 0.000673 0.004866 0.281340 0.000284 ... 0.191269 0.024719 0.415103 1.336383 0.057381 0.522568 0.434491 0.217985 0.406575 0.283343
1 0.000184 0.006229 0.002463 0.001197 0.102545 0.178865 0.000516 0.005317 0.183479 0.000223 ... 0.262333 0.497820 0.192020 0.301240 0.169246 0.062231 0.840451 3.160916 0.018541 0.194022
2 0.000444 0.007750 0.002796 0.001045 0.126620 0.224383 0.000616 0.002492 0.092823 0.000106 ... 0.265266 0.429609 0.253233 0.102961 0.089436 0.212794 0.470108 2.075980 0.330177 0.401309
3 0.000131 0.005346 0.001581 0.001190 0.124439 0.318437 0.000681 0.002097 0.062963 0.000183 ... 1.440256 2.710825 0.023943 0.199657 0.475663 0.060038 0.732850 0.527692 0.293737 0.043358
4 0.000338 0.006431 0.004826 0.001502 0.127398 0.353062 0.000712 0.002751 0.203407 0.000304 ... 0.077847 0.568676 0.662725 0.075138 0.195219 1.579705 1.172825 1.220631 1.092952 2.879782
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
171 0.000093 0.002115 0.005041 0.002880 0.083908 0.650543 0.000448 0.006777 0.135922 0.000169 ... 0.198178 0.145257 0.667060 0.701092 0.338362 0.314434 0.632372 0.738904 0.036569 0.711834
172 0.000397 0.005551 0.003777 0.001320 0.113525 0.468127 0.001020 0.002231 0.079646 0.000074 ... 0.707488 2.385176 0.006436 1.460262 0.248579 0.417530 1.266310 2.967516 0.447070 0.510578
173 0.000258 0.002662 0.001886 0.000997 0.092180 0.201172 0.000416 0.003316 0.206956 0.000199 ... 2.652632 0.265732 0.157335 0.033217 0.048084 0.306286 0.742069 0.603324 0.534728 0.403861
174 0.000233 0.003904 0.005026 0.003179 0.115082 0.625989 0.000691 0.003087 0.216596 0.000218 ... 0.402888 0.094320 0.611828 1.394516 0.501248 0.985927 0.172135 0.925321 0.561355 1.262532
175 0.000150 0.005762 0.004460 0.002486 0.114589 0.769496 0.000627 0.004954 0.234761 0.000121 ... 0.283001 0.508711 0.416403 1.006729 0.055094 1.160128 0.962558 1.547616 0.307633 0.756805

176 rows × 1024 columns
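Distances in this 1024-dimensional feature space tend to reflect image semantics rather than raw pixels. As a quick check (a minimal sketch using the true labels, which we wouldn’t normally have when clustering), we can compare the distance between two images of the same species against the distance to a different species; typically, though not always, the within-species distance is smaller:

from sklearn.metrics import euclidean_distances

labels = birds_classes.numpy()
same = np.where(labels == labels[0])[0][:2]   # two images of the same species
other = np.where(labels != labels[0])[0][0]   # an image of a different species
print("same species:     ", euclidean_distances(Z_birds[[same[0]]], Z_birds[[same[1]]])[0, 0])
print("different species:", euclidean_distances(Z_birds[[same[0]]], Z_birds[[other]])[0, 0])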

Do we get better clustering with this representation?

from sklearn.cluster import KMeans

k = 3
km = KMeans(n_clusters=k, n_init='auto', random_state=123)
km.fit(Z_birds)
KMeans(n_clusters=3, random_state=123)
km.cluster_centers_.shape
(3, 1024)
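Since read_img_dataset also returned the true species labels, we can quantify how well the clusters line up with the three species using the adjusted Rand index (1.0 is a perfect match, around 0.0 is chance level). This is just a sanity check for the demo; in a real unsupervised setting we wouldn’t have these labels:

from sklearn.metrics import adjusted_rand_score

# Compare KMeans cluster assignments against the true species labels
adjusted_rand_score(birds_classes.numpy(), km.labels_)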
for cluster in range(k):
    # user-defined functions defined in ../code/plotting_functions.py
    get_cluster_images(km, Z_birds, X_birds, cluster, n_img=6)
Image indices:  [103  23  86 162 168 122]
[figure: sample images from cluster 0]
Image indices:  [55 31 53 15 88 84]
[figure: sample images from cluster 1]
Image indices:  [120   5  11  14  22  69]
[figure: sample images from cluster 2]

KMeans seems to be doing a good job. But the cluster centers now live in the 1024-dimensional feature space, so they are no longer interpretable as images.
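One workaround, presumably close to what get_cluster_images does internally, is to display the image whose feature vector is nearest to each center; a minimal sketch:

from sklearn.metrics import pairwise_distances_argmin

# For each cluster center, the index of the nearest feature vector
closest = pairwise_distances_argmin(km.cluster_centers_, Z_birds)

fig, axes = plt.subplots(1, k, figsize=(12, 4))
for cluster, ax in zip(range(k), axes):
    ax.imshow(np.transpose(X_birds[closest[cluster]] / 2 + 0.5, (1, 2, 0)))
    ax.set_title(f"Closest to center {cluster}")
    ax.axis("off")

Let’s try GMMs.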

gmm = GaussianMixture(n_components=k, random_state=123)
gmm.fit(Z_birds)
GaussianMixture(n_components=3, random_state=123)
gmm.weights_
array([0.34090909, 0.32386364, 0.33522727])
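Unlike KMeans, a GMM gives soft assignments: predict_proba returns each image’s responsibility under each of the k components, with rows summing to 1. A quick look at the first few rows (a minimal sketch):

# Soft cluster assignments for the first five images
probs = gmm.predict_proba(Z_birds)
pd.DataFrame(probs[:5]).round(3)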
for cluster in range(k):
    # user-defined functions defined in ../code/plotting_functions.py
    get_cluster_images(gmm, Z_birds, X_birds, cluster, n_img=6)
Image indices:  [107 106  42 103 135  87]
[figure: sample images from cluster 0]
Image indices:  [ 84  79  78  77 100 175]
[figure: sample images from cluster 1]
Image indices:  [137  61  28 141  25 124]
[figure: sample images from cluster 2]

Cool! Both models are doing a great job with this representation!! This dataset seems relatively easy, as the three bird species have very distinct colors. Let’s try a slightly more complicated dataset.

data_dir = os.path.join(DATA_DIR, "food")
file_names = [image_file for image_file in glob.glob(data_dir + "/*/*.jpg")]
n_images = len(file_names)
BATCH_SIZE = n_images  # because our dataset is quite small
food_inputs, food_classes = read_img_dataset(data_dir)
n_images
350
X_food = food_inputs.numpy()
plot_sample_imgs(food_inputs[0:24,:,:,:])
[figure: grid of sample food images]
Z_food = get_features(densenet, food_inputs)
Z_food.shape
(350, 1024)
from sklearn.cluster import KMeans

k = 5
km = KMeans(n_clusters=k, n_init='auto', random_state=123)
km.fit(Z_food)
KMeans(n_clusters=5, random_state=123)
km.cluster_centers_.shape
(5, 1024)
for cluster in range(k):
    get_cluster_images(km, Z_food, X_food, cluster, n_img=6)
Image indices:  [ 84 169 328   0 143  12]
[figure: sample images from cluster 0]
Image indices:  [263  80 257 301  44 326]
[figure: sample images from cluster 1]
Image indices:  [188   1 339 273  55 238]
[figure: sample images from cluster 2]
Image indices:  [282 150 177 138 116 123]
[figure: sample images from cluster 3]
Image indices:  [ 20  39 332  15 226 322]
[figure: sample images from cluster 4]

There are some mis-groupings, but overall the clusters look pretty good! You can experiment with:

  • Different values for the number of clusters (see the silhouette-score sketch after this list)

  • Different pre-trained models

  • Other possible representations

  • Different image datasets
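For the first point, one rough way to compare different numbers of clusters is the silhouette score (a minimal sketch, reusing the Z_food features from above; silhouette scores are only a crude guide for high-dimensional image features):

from sklearn.metrics import silhouette_score

# Compare KMeans clusterings of the food features for several values of k
for n_clusters in range(2, 9):
    labels = KMeans(n_clusters=n_clusters, n_init='auto', random_state=123).fit_predict(Z_food)
    print(n_clusters, silhouette_score(Z_food, labels))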

See an example of using K-Means clustering for customer segmentation in Appendix A.