'데이터 엔지니어'로 성장하기

정리하는 걸 좋아하고, 남이 읽으면 더 좋아함

AI/Vision

Vision-AI) prediction mask 시각화하기 (train용 mask와 class별로 비교)

MightyTedKim 2024. 5. 8. 04:00
728x90
반응형

이전 포스팅에서는 학습 전 mask를 확인하는 코드를 소개했어요

https://mightytedkim.tistory.com/213

 

Vision-AI) Mask 정리해서 시각화하기(one-hot, show_channel)

semantic segmentation에서 mask가 제대로 만들어졌는지 궁금할 때가 있어요. 이걸 위해서 MaskViewer라는 class를 만들었어요. 원본 이미지와 mask 이미지만 보려고 할 때는 아래처럼 show_channels=False를 입력

mightytedkim.tistory.com

 

이번에는 prediction도 함께 비교하는 코드를 소개하려고 해요

 

이전 포스팅에서의 MaskViewer를 사용해도, prediction mask를 똑같이 볼 수 있어요

 

하지만 어떤 class가 매칭이 되는지 한눈에 보고 싶어서 DIR_PREDICT 변수를 하나 더 추가했어요

 

참고로 각각의 변수들은 다음과 같아요

 

코드는 아래와 같습니다.

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

class ImageVisualizer:
    """Display an original image next to its training mask and its prediction mask.

    For every image file found in ``original_dir`` the mask and the prediction
    are looked up by the same base name with a ``.png`` extension in
    ``mask_dir`` and ``prediction_dir``.

    Args:
        original_dir: Directory holding the original images.
        mask_dir: Directory holding the masks (``<stem>.png``).
        prediction_dir: Directory holding the predicted masks (``<stem>.png``).
        cat_names: Sequence of class names; position ``i`` names class index ``i``.
        max_images: Maximum number of images to display; ``None`` means no limit.
    """

    def __init__(self, original_dir, mask_dir, prediction_dir, cat_names, max_images=None):
        self.original_dir = original_dir
        self.mask_dir = mask_dir
        self.prediction_dir = prediction_dir
        self.cat_names = cat_names
        # float('inf') lets the "count >= max_images" check work without a None branch.
        self.max_images = max_images if max_images is not None else float('inf')

    @staticmethod
    def is_image_file(filename):
        """Return True if *filename* ends with a known image extension (case-insensitive)."""
        valid_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
        return filename.lower().endswith(valid_extensions)

    def convert_to_one_hot(self, mask, num_classes):
        """Convert a mask of class indices into a one-hot array.

        Args:
            mask: Array of class indices.
            num_classes: Number of channels to produce.

        Returns:
            A ``uint8`` array of shape ``mask.shape + (num_classes,)`` where
            channel ``i`` is 1 wherever ``mask == i`` and 0 elsewhere. Values
            outside ``range(num_classes)`` leave every channel at 0.
        """
        mask = np.asarray(mask)
        one_hot = np.zeros(mask.shape + (num_classes,), dtype=np.uint8)
        for class_index in range(num_classes):
            one_hot[..., class_index] = (mask == class_index).astype(np.uint8)
        return one_hot

    def visualize_images_channels(self, image, channel_names):
        """Plot every channel of a multi-channel image in a single row.

        Args:
            image: Array whose last axis is the channel axis.
            channel_names: Names used for the subplot titles, indexed by channel.
                Channels without a corresponding name are titled ``'?'``.
        """
        num_channels = image.shape[-1]
        plt.figure(figsize=(25, 5))
        for i in range(num_channels):
            # Guard against masks that carry more channels than known class names.
            name = channel_names[i] if i < len(channel_names) else '?'
            plt.subplot(1, num_channels, i + 1)
            plt.imshow(image[..., i], cmap='gray')
            plt.title(f'{i}: {name}', fontsize=6)
            plt.axis('off')
        plt.show()

    def visualize_images(self, original, mask, prediction, titles):
        """Plot the original image, the mask and the prediction side by side.

        Args:
            original: Image in BGR channel order (as returned by ``cv2.imread``).
            mask: Mask image, drawn with the ``nipy_spectral`` colormap.
            prediction: Predicted mask, drawn with the ``nipy_spectral`` colormap.
            titles: Three subplot titles, in the order original, mask, prediction.
        """
        images = [original, mask, prediction]
        plt.figure(figsize=(15, 5))
        for i, img in enumerate(images):
            # One row with one column per image, so each image gets the full height.
            plt.subplot(1, len(images), i + 1)
            if i == 0:
                # matplotlib expects RGB, OpenCV loads BGR.
                plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            else:
                plt.imshow(img, cmap='nipy_spectral')
            plt.title(titles[i], fontsize=10)
            plt.axis('off')
        plt.show()

    def _show_channels(self, image):
        """Plot the per-class channels of a mask.

        A 2-D mask is one-hot encoded first; a 3-D mask with more than three
        channels is plotted as is. Any other layout is skipped.
        """
        if image.ndim == 3 and image.shape[-1] > 3:
            self.visualize_images_channels(image, self.cat_names)
        elif image.ndim == 2:
            one_hot = self.convert_to_one_hot(image, len(self.cat_names))
            self.visualize_images_channels(one_hot, self.cat_names)

    def process_images(self, show_channels=False):
        """Visualize up to ``max_images`` image/mask/prediction triplets.

        Files are visited in sorted name order so that the displayed subset is
        reproducible. Triplets with an unreadable file are reported and do not
        count towards ``max_images``.

        Args:
            show_channels: When True, also plot the per-class channels of the
                mask and of the prediction.
        """
        count = 0
        # os.listdir returns entries in arbitrary order; sort for determinism.
        for filename in sorted(os.listdir(self.original_dir)):
            if count >= self.max_images:
                break
            if not self.is_image_file(filename):
                continue

            stem = os.path.splitext(filename)[0]
            original_file_path = os.path.join(self.original_dir, filename)
            mask_file_path = os.path.join(self.mask_dir, stem + '.png')
            prediction_file_path = os.path.join(self.prediction_dir, stem + '.png')

            original_image = cv2.imread(original_file_path)
            # IMREAD_UNCHANGED keeps the stored channel count and bit depth of the masks.
            mask_image = cv2.imread(mask_file_path, cv2.IMREAD_UNCHANGED)
            predicted_image = cv2.imread(prediction_file_path, cv2.IMREAD_UNCHANGED)

            if original_image is None or mask_image is None or predicted_image is None:
                print(f"Error loading one or more images for {filename}.")
                continue

            titles = ['Origin:', 'Mask', 'Predict']
            print(prediction_file_path)
            self.visualize_images(original_image, mask_image, predicted_image, titles)

            if show_channels:
                self._show_channels(mask_image)
                self._show_channels(predicted_image)

            count += 1

 

 

semantic segmentation 진행하시는 분들께 도움이 되었으면 좋겠습니다.

728x90
반응형