GSOC WEEK 3

This week we worked entirely on the image encoder part. Because the SAM model was trained on color (RGB) images, our data needs to be converted into the shape accepted by the original decoder. To do this, we use some functions from the SAM encoder file and convert the files as required.

# %% import packages
import numpy as np
import os
from glob import glob
import pandas as pd

join = os.path.join
from skimage import transform, io, segmentation
from tqdm import tqdm
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import argparse

# Command-line interface for the preprocessing script.
def _build_parser():
    """Create the argument parser with every option this script understands.

    Arguments are registered in table order, so the generated ``--help``
    text lists them in the same sequence as the table below.
    """
    cli = argparse.ArgumentParser(description="preprocess grey and RGB images")

    # One (flags, keyword-options) entry per command-line argument.
    argument_table = [
        # input / output locations
        (
            ("-i", "--img_path"),
            {
                "type": str,
                "default": "data/nucleus_data/train/images",
                "help": "path to the images",
            },
        ),
        (
            ("-gt", "--gt_path"),
            {
                "type": str,
                "default": "data/nucleus_data/train/labels",
                "help": "path to the ground truth (gt)",
            },
        ),
        (
            ("--csv",),
            {"type": str, "default": None, "help": "path to the csv file"},
        ),
        (
            ("-o", "--npz_path"),
            {
                "type": str,
                "default": "data/demo2D",
                "help": "path to save the npz files",
            },
        ),
        (
            ("--data_name",),
            {
                "type": str,
                "default": "demo2d",
                "help": "dataset name; used to name the final npz file, e.g., demo2d.npz",
            },
        ),
        # image / label handling
        (
            ("--image_size",),
            {"type": int, "default": 256, "help": "image size"},
        ),
        (
            ("--img_name_suffix",),
            {"type": str, "default": ".png", "help": "image name suffix"},
        ),
        (
            ("--label_id",),
            {"type": int, "default": 255, "help": "label id"},
        ),
        # model selection and runtime
        (
            ("--model_type",),
            {"type": str, "default": "vit_b", "help": "model type"},
        ),
        (
            ("--checkpoint",),
            {
                "type": str,
                "default": "work_dir/SAM/sam_vit_b_01ec64.pth",
                "help": "checkpoint",
            },
        ),
        (
            ("--device",),
            {"type": str, "default": "cuda:0", "help": "device"},
        ),
        (
            ("--seed",),
            {"type": int, "default": 2023, "help": "random seed"},
        ),
    ]
    for flags, options in argument_table:
        cli.add_argument(*flags, **options)
    return cli


parser = _build_parser()

# Parsed options, read as module-level configuration by the rest of the script.
args = parser.parse_args()

# Accumulators for the 2D grey/RGB -> npz conversion. process() appends one
# entry to each list for every accepted image / ground-truth pair.
imgs, gts, img_embeddings = [], [], []

# Build the SAM variant selected by --model_type from the weights given by
# --checkpoint, then place it on --device.
# The checkpoint has to be downloaded beforehand; the ViT-B SAM model is at
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
sam_model = sam_model.to(args.device)


def process(gt_name: str, image_name: "str | None" = None):
    """Preprocess one mask/image pair and compute its SAM image embedding.

    The mask is binarised against ``args.label_id`` and resized to
    ``args.image_size``. Pairs whose resized mask has 100 or fewer foreground
    pixels are skipped. For accepted pairs the image is converted to three
    channels, clipped to its 0.5-99.5 percentile intensity range, rescaled to
    0-255, resized, and passed through ``sam_model.image_encoder``.

    Reads the module-level ``args`` and ``sam_model`` and appends to the
    module-level lists ``imgs``, ``gts`` and ``img_embeddings``.

    Args:
        gt_name: File name of the ground-truth mask inside ``args.gt_path``.
        image_name: File name of the image inside ``args.img_path``. When
            ``None``, it is derived from ``gt_name`` by replacing the file
            extension with ``args.img_name_suffix``.

    Returns:
        None. Results are accumulated in the module-level lists.
    """
    if image_name is None:
        # splitext strips only the final extension, so stems that themselves
        # contain dots (e.g. "slice.001.png") are preserved.
        image_name = os.path.splitext(gt_name)[0] + args.img_name_suffix

    gt_data = io.imread(join(args.gt_path, gt_name))
    # if it is rgb, select the first channel
    if gt_data.ndim == 3:
        gt_data = gt_data[:, :, 0]
    assert gt_data.ndim == 2, "ground truth should be 2D"

    # Binarise against the requested label, then resize with nearest-neighbour
    # interpolation (order=0) so no intermediate label values are created.
    gt_data = transform.resize(
        gt_data == args.label_id,
        (args.image_size, args.image_size),
        order=0,
        preserve_range=True,
        mode="constant",
    )
    gt_data = np.uint8(gt_data)

    if np.sum(gt_data) <= 100:  # exclude tiny objects
        return

    # Optional binary thresholding can be added here.
    assert (
        np.max(gt_data) == 1 and np.unique(gt_data).shape[0] == 2
    ), "ground truth should be binary"

    image_data = io.imread(join(args.img_path, image_name))
    # Remove any alpha channel if present.
    if image_data.ndim == 3 and image_data.shape[-1] > 3:
        image_data = image_data[:, :, :3]
    # If image is grayscale, repeat the single channel to convert to rgb.
    if image_data.ndim == 2:
        image_data = np.repeat(image_data[:, :, None], 3, axis=-1)

    # Clip intensity outliers to the 0.5 / 99.5 percentiles.
    lower_bound = np.percentile(image_data, 0.5)
    upper_bound = np.percentile(image_data, 99.5)
    image_data_pre = np.clip(image_data, lower_bound, upper_bound)

    # Min-max normalise to 0-255. A constant-intensity image has zero range;
    # dividing by it would fill the array with NaN, so map it to zeros instead.
    min_value = np.min(image_data_pre)
    value_range = np.max(image_data_pre) - min_value
    if value_range > 0:
        image_data_pre = (image_data_pre - min_value) / value_range * 255.0
    else:
        image_data_pre = np.zeros(image_data_pre.shape, dtype=np.float64)
    # Keep pixels that were exactly zero in the source image at zero.
    image_data_pre[image_data == 0] = 0

    image_data_pre = transform.resize(
        image_data_pre,
        (args.image_size, args.image_size),
        order=3,
        preserve_range=True,
        mode="constant",
        anti_aliasing=True,
    )
    image_data_pre = np.uint8(image_data_pre)

    # Resize to the encoder's input size and move to the target device as a
    # channel-first tensor.
    encoder_size = sam_model.image_encoder.img_size
    sam_transform = ResizeLongestSide(encoder_size)
    resize_img = sam_transform.apply_image(image_data_pre)
    resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(
        args.device
    )
    input_image = sam_model.preprocess(resize_img_tensor[None, :, :, :])
    assert input_image.shape == (
        1,
        3,
        encoder_size,
        encoder_size,
    ), "input image should be resized to 1024*1024"

    # pre-compute the image embedding
    with torch.no_grad():
        embedding = sam_model.image_encoder(input_image)

    # Append to all three lists together, only once every step has succeeded,
    # so the lists always stay index-aligned.
    imgs.append(image_data_pre)
    gts.append(gt_data)
    img_embeddings.append(embedding.cpu().numpy()[0])  # drop the batch axis


# Drive process() over every mask/image pair, then write the results to disk.
if args.csv is not None:
    # Data is presented in csv format; the columns must be named
    # image_filename and mask_filename respectively.
    # os.path.exists never raises, so the missing-file case is checked
    # explicitly and reported before pandas is asked to read the file.
    if not os.path.exists(args.csv):
        raise FileNotFoundError(f"File {args.csv} not found!!")

    df = pd.read_csv(args.csv)
    for _, row in tqdm(df.iterrows(), total=len(df)):
        process(row.mask_filename, row.image_filename)
else:
    # get all the names of the images in the ground truth folder
    names = sorted(os.listdir(args.gt_path))
    # print the number of images found in the ground truth folder
    print("image number:", len(names))
    for gt_name in tqdm(names):
        process(gt_name, None)

# create a directory to save the npz files
save_path = args.npz_path + "_" + args.model_type
os.makedirs(save_path, exist_ok=True)


# save all 2D images as one npz file: imgs, gts, img_embeddings
# stack the lists to arrays
print("Num. of images:", len(imgs))
if len(imgs) > 0:  # a single valid pair is still a valid dataset
    imgs = np.stack(imgs, axis=0)  # (n, image_size, image_size, 3)
    gts = np.stack(gts, axis=0)  # (n, image_size, image_size)
    # (n, 256, 64, 64): process() already dropped the per-image batch axis
    img_embeddings = np.stack(img_embeddings, axis=0)
    np.savez_compressed(
        join(save_path, args.data_name + ".npz"),
        imgs=imgs,
        gts=gts,
        img_embeddings=img_embeddings,
    )
    # save an example image for sanity check; seeding with --seed makes the
    # chosen example reproducible between runs
    np.random.seed(args.seed)
    idx = np.random.randint(imgs.shape[0])
    # copy so that drawing the boundary does not modify the stacked array
    img_idx = imgs[idx, :, :, :].copy()
    gt_idx = gts[idx, :, :]
    bd = segmentation.find_boundaries(gt_idx, mode="inner")
    img_idx[bd, :] = [255, 0, 0]  # mark the mask boundary in red
    io.imsave(save_path + ".png", img_idx, check_contrast=False)
else:
    print(
        "Do not find image and ground-truth pairs. Please check your dataset and argument settings"
    )

Explanation of the code

The complete code preprocesses grayscale and RGB images and saves them in npz file format. Here is a detailed explanation of each section of the code:

  1. Importing packages:

    • The code imports the necessary packages: numpy, os, glob, pandas, skimage, tqdm, torch, segment_anything, and argparse. These packages provide functionality for file operations, image processing, progress tracking, deep learning, and command-line argument parsing.
  2. Parsing command-line arguments:

    • The code sets up an argument parser using argparse.ArgumentParser() to handle command-line arguments.
    • Several arguments are added to the parser, including paths to the input images (img_path) and ground truth (gt_path), CSV file (csv), output npz file path (npz_path), dataset name (data_name), image size (image_size), image name suffix (img_name_suffix), label ID (label_id), model type (model_type), checkpoint path (checkpoint), device (device), and random seed (seed).
    • The args variable stores the parsed arguments.
  3. Initializing variables and setting up the SAM model:

    • Several empty lists (imgs, gts, img_embeddings) are created to store processed images, ground truths, and image embeddings.
    • The SAM model is set up using the sam_model_registry and the provided model_type and checkpoint arguments.
    • The SAM model is moved to the specified device.
  4. Defining the process function:

    • The process function is defined to preprocess each ground truth and image pair.
    • It takes the names of the ground truth and image as input.
    • The function loads the ground truth data using io.imread and converts it to a 2D array if necessary.
    • The ground truth is resized to the specified image_size using transform.resize.
    • If the sum of the ground truth data is greater than 100, further processing is performed.
    • The image data is loaded using io.imread and converted to RGB if necessary.
    • Additional preprocessing steps are applied to the image data, including clipping, min-max normalization, and resizing.
    • The preprocessed image data is added to the imgs list.
    • The ground truth is added to the gts list.
    • The image is resized to the appropriate size for the SAM model using ResizeLongestSide.
    • The image tensor is preprocessed using sam_model.preprocess.
    • The image embedding is computed using the SAM model’s image encoder.
    • The image embedding is added to the img_embeddings list.
  5. Processing images and ground truths:

    • If the csv argument is provided, the code reads the CSV file using pd.read_csv and iterates over the rows.
    • For each row, the process function is called with the ground truth filename and image filename from the CSV.
    • If the csv argument is not provided, the code gets the names of all images in the ground truth folder and iterates over them.
    • For each image, the process function is called with the ground truth filename set to the current image name and the image name set to None.
  6. Saving the npz files:

    • A directory is created to save the npz files based on the npz_path and model_type arguments.
    • The processed images, ground truths, and image embeddings are stacked into numpy arrays.
    • If there is more than one image, the arrays are saved as a compressed npz file using np.savez_compressed.
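    • An example image is randomly selected, the boundary of its ground-truth mask is drawn on it in red, and it is saved as a PNG file for a sanity check.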

    The final output of the image encoder has shape (B, 256, 64, 64), where B is the number of images.

    The whole week was spent converting grayscale and RGB images into the shape required by the decoder.