
Multimodal Vision-Language Search on Satellite Images

Amir Afshari · 3 min read


With the rapid advancement of remote sensing technologies, the availability of large-scale satellite image datasets has grown exponentially. These datasets contain invaluable information for various applications, including environmental monitoring, urban planning, and disaster management. However, extracting specific categories of objects, such as identifying all ships within a dataset of one million samples, presents a significant challenge due to the sheer volume of data and the complexity of manual analysis.

Solution

This task, which is overwhelming for human analysts, can be addressed efficiently with vector search. By using a deep learning model to transform images into high-dimensional vectors, specifically a multimodal vision-text model such as CLIP, we can run a nearest-neighbor search to quickly and accurately retrieve images based on their content. For instance, you can use "red ship floating on the sea" as the query, and the system returns the matching images from the dataset even though no metadata is available. This makes finding specific categories within vast datasets both faster and far less labor-intensive.
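In a few lines, the whole idea looks roughly like this (a minimal sketch with hypothetical file names; the full walkthrough below uses the same building blocks):

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Hypothetical placeholders for a handful of satellite images
images = [Image.open(p) for p in ["scene_1.png", "scene_2.png"]]

with torch.no_grad():
    # Embed the images and the text query into the same vector space
    image_emb = model.encode_image(torch.stack([preprocess(im) for im in images]).to(device))
    text_emb = model.encode_text(clip.tokenize(["red ship floating on the sea"]).to(device))

    # Cosine similarity: normalize, then take dot products
    image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    scores = (image_emb @ text_emb.T).squeeze(1)

# Images ranked from most to least similar to the query
ranking = scores.argsort(descending=True)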

Good to know

  • I did not fine-tune CLIP on remote-sensing text-image pairs, but it still works reasonably well; fine-tuning would improve accuracy further.
  • It's not a production-level (or even near-production-level) solution, so there is plenty of room for improvement in both speed and accuracy.
  • We can use a vector database such as Weaviate or Redis instead of a simple Python list.
  • The dataset is limited to approximately 14,000 samples (a combination of AID, FAIR1M_partial, RESICS_partial, ...).
  • We can binarize the vectors to improve speed (see the sketch after this list).
  • You can find the code on GitHub.
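As a rough illustration of the binarization point above (a hedged sketch, not part of the original notebook; `f` and `query` refer to the image and text embeddings defined later in the post):

import numpy as np
from sklearn.neighbors import NearestNeighbors

def binarize(vectors: np.ndarray) -> np.ndarray:
    """Sign-threshold each component: 1 where positive, 0 otherwise."""
    return (vectors > 0).astype(np.uint8)

# f: (N, 512) image embeddings, query: (1, 512) text embedding (NumPy arrays)
# neigh = NearestNeighbors(n_neighbors=50, metric='hamming').fit(binarize(f))
# distances, indices = neigh.kneighbors(binarize(query))

Comparing bit vectors with Hamming distance trades some accuracy for smaller memory and cheaper comparisons.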
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn.neighbors import NearestNeighbors

import clip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

Model

model, preprocess = clip.load("ViT-B/32", device=device)

Image Embedding

# Walk through the dataset directories, find images, and encode each one into a CLIP feature vector
image_paths = []
features = []
for root, dirs, files in os.walk('datasets/'):
    for file in files:
        if file.endswith((".jpg", ".tif", ".png")):
            file_path = os.path.join(root, file)
            image_paths.append(file_path)

            img = Image.open(file_path)

            with torch.no_grad():
                feature = model.encode_image(preprocess(img).unsqueeze(0).to(device)).detach().cpu().numpy()
            features.append(feature)


print(f'{len(features)} Images Found!')
f = np.concatenate(features, axis=0)

14099 Images Found!

  • Took ~ 2m 30s on NVIDIA 3060
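Encoding one image at a time leaves the GPU largely idle. A hedged sketch of batched encoding (the batch size of 64 is an arbitrary choice, and the snippet reuses model, preprocess, device, and image_paths from above):

batch_size = 64  # arbitrary; tune to fit GPU memory
features = []
with torch.no_grad():
    for i in range(0, len(image_paths), batch_size):
        # Preprocess a chunk of images and encode them in one forward pass
        batch = torch.stack([preprocess(Image.open(p)) for p in image_paths[i:i + batch_size]])
        features.append(model.encode_image(batch.to(device)).cpu().numpy())
f = np.concatenate(features, axis=0)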

Dataset Visualization

# Randomly shuffle the image paths to visualize a sample of the dataset
randompath = image_paths.copy()
random.shuffle(randompath)


# Plot a 10x5 grid of random samples
rows, columns = 10, 5
fig, axs = plt.subplots(rows, columns, figsize=(20, 50), squeeze=True)


n = 0
for i in range(rows):
    for j in range(columns):

        axs[i][j].imshow(Image.open(randompath[n]))
        axs[i][j].set_title(f'Instance {n}')
        n += 1
plt.show()

(Figure: a 10×5 grid of random samples from the dataset)

Text Embedding & Search

# Text Embedding (query feature)
query = 'stadium'
query = clip.tokenize(query).to(device)
with torch.no_grad():
    query = model.encode_text(query).detach().cpu().numpy()

# Nearest-neighbor search over the image features
k = 50
neigh = NearestNeighbors(n_neighbors=k, algorithm='brute')
neigh.fit(f)
distances, indices = neigh.kneighbors(query)

# Plot the k nearest neighbors in a 10x5 grid
rows, columns = 10, 5
fig, axs = plt.subplots(rows, columns, figsize=(20, 50))


n = 0
for i in range(rows):
    for j in range(columns):

        axs[i][j].imshow(Image.open(image_paths[indices[0][n]]))
        axs[i][j].set_title(f'Nearest Neighbor {n}')
        n += 1
plt.show()

(Figure: the 50 nearest neighbors retrieved for the query 'stadium')
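CLIP similarities are usually computed on L2-normalized embeddings, i.e. cosine similarity, whereas the search above compares raw vectors with Euclidean distance. A hedged variant of the same scikit-learn search using a cosine metric (reusing f, query, and k from above):

# Same search as above, but with cosine distance instead of Euclidean
neigh = NearestNeighbors(n_neighbors=k, metric='cosine', algorithm='brute')
neigh.fit(f)
distances, indices = neigh.kneighbors(query)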