'데이터 엔지니어'로 성장하기

정리하는 걸 좋아하고, 남이 읽으면 더 좋아함

AI/Vision

Vision-AI) Mask 정리해서 시각화하기(one-hot, show_channel)

MightyTedKim 2024. 5. 8. 03:46
728x90
반응형

semantic segmentation에서

mask가 제대로 만들어졌는지 궁금할 때가 있어요.

 

이걸 위해서 MaskViewer라는 class를 만들었어요.

 

원본 이미지와 mask 이미지만 보려고 할 때는 아래처럼 show_channels=False를 입력하면 됩니다.

 

채널별로 보고 싶으면 show_channels=True를 입력하면 돼요.

 

코드는 아래와 같습니다.

class MaskViewer:
    """Display images next to their segmentation masks for visual inspection.

    Images are read from ``input_img_dir``; for each image the mask is looked
    up in ``input_mask_dir`` under the same base name with a ``.png``
    extension. The mask is treated as an array of class indices, where index
    ``i`` corresponds to ``cat_names[i]``.

    The enclosing module must provide ``os``, ``numpy`` as ``np``, ``cv2`` and
    ``matplotlib.pyplot`` as ``plt``.
    """

    # Lower-case file extensions accepted as images. Kept as a tuple so it can
    # be handed directly to ``str.endswith``.
    VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, input_img_dir, input_mask_dir, cat_names):
        """Store the image directory, mask directory and ordered class names."""
        self.input_img_dir = input_img_dir
        self.input_mask_dir = input_mask_dir
        self.cat_names = cat_names

    def is_image_file(self, filename):
        """Return True if ``filename`` ends with a known image extension.

        The comparison is case-insensitive.
        """
        return filename.lower().endswith(self.VALID_EXTENSIONS)

    def visualize_images(self, original, mask):
        """Show the original image and its mask side by side.

        ``original`` is converted from BGR to RGB before display; ``mask`` is
        drawn with the ``nipy_spectral`` colormap.
        """
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(mask, cmap='nipy_spectral')
        plt.title('Mask Image(one-hot-encoded)')
        plt.axis('off')

        plt.show()

    def visualize_images_channels(self, image, channel_names):
        """Show every channel (last axis) of ``image`` in one row of subplots.

        Each subplot is titled ``"<index>: <channel_names[index]>"``, so
        ``channel_names`` needs at least as many entries as ``image`` has
        channels.
        """
        num_channels = image.shape[-1]
        plt.figure(figsize=(10, 3))
        for i in range(num_channels):
            plt.subplot(1, num_channels, i + 1)
            plt.imshow(image[..., i], cmap='gray')
            plt.title(f'{i}: {channel_names[i]}', fontsize=6)
            plt.axis('off')
        plt.tight_layout()
        plt.show()

    @staticmethod
    def _convert_to_one_hot(mask, num_classes):
        """Convert a class-index mask to a one-hot ``uint8`` array.

        The result has shape ``mask.shape + (num_classes,)``; channel ``i`` is
        1 where ``mask == i`` and 0 elsewhere.
        """
        # Broadcasting a trailing axis against the class indices compares
        # every pixel with every class in one vectorised operation.
        class_ids = np.arange(num_classes)
        return (mask[..., np.newaxis] == class_ids).astype(np.uint8)

    def visualize_image_set(self, filename, show_channels):
        """Load one image and its mask by ``filename`` and display them.

        Prints the mask path that is being read. If either file cannot be
        loaded, prints an error message and returns without plotting. When
        ``show_channels`` is true, the mask is additionally one-hot encoded
        with one channel per entry of ``self.cat_names`` and each channel is
        shown separately.
        """
        original_image_path = os.path.join(self.input_img_dir, filename)
        mask_stem = os.path.splitext(filename)[0]
        mask_file_path = os.path.join(self.input_mask_dir, mask_stem + '.png')
        print(f"{mask_file_path}")

        original_image = cv2.imread(original_image_path)
        # IMREAD_UNCHANGED keeps the stored pixel values instead of converting
        # the mask to a 3-channel BGR image.
        mask_image = cv2.imread(mask_file_path, cv2.IMREAD_UNCHANGED)

        if original_image is None or mask_image is None:
            print(f"Error loading images for {filename}.")
            return

        self.visualize_images(original_image, mask_image)
        if show_channels:
            mask_one_hot = self._convert_to_one_hot(
                mask_image, len(self.cat_names)
            )
            self.visualize_images_channels(mask_one_hot, self.cat_names)

    def process_and_visualize_n_masks(self, n, show_channels):
        """Visualize up to ``n`` images from ``input_img_dir`` with their masks.

        Files are processed in sorted name order so that the same ``n`` files
        are selected on every run. A non-positive ``n`` shows nothing.
        """
        # os.listdir returns entries in arbitrary order; sort for a
        # reproducible "first n" selection.
        filenames = sorted(
            f for f in os.listdir(self.input_img_dir) if self.is_image_file(f)
        )
        for filename in filenames[:max(n, 0)]:
            self.visualize_image_set(filename, show_channels)

 

mask를 이용하시는 분들께 도움이 되었으면 좋겠습니다.

728x90
반응형