AI/Vision
Vision-AI) prediction mask 시각화하기 (train용 mask와 class별로 비교)
MightyTedKim
2024. 5. 8. 04:00
728x90
반응형
이전 포스팅에서는 학습 전 mask를 확인하는 코드를 소개했어요
https://mightytedkim.tistory.com/213
이번에는 prediction도 함께 비교하는 코드를 소개하려고 해요
이전 포스팅에서의 MaskViewer를 사용해도, prediction mask를 똑같이 볼 수 있어요
하지만 어떤 class가 매칭이 되는지 한눈에 보고 싶어서 DIR_PREDICT 변수를 하나 더 추가했어요
참고로 각각의 변수들은 다음과 같아요
코드는 아래와 같습니다.
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
class ImageVisualizer:
    """Display original images next to their training masks and predicted masks.

    For every image file found in ``original_dir`` the visualizer looks for a
    ``.png`` file with the same base name in ``mask_dir`` and in
    ``prediction_dir`` and plots the three images in one row. Optionally it
    also plots every class of the mask and of the prediction as its own
    grayscale panel, so the two can be compared class by class.
    """

    # Lower-case file extensions that is_image_file accepts.
    VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, original_dir, mask_dir, prediction_dir, cat_names, max_images=None):
        """Store the directories, class names and the image limit.

        Args:
            original_dir: Directory containing the original images.
            mask_dir: Directory containing the mask ``.png`` files.
            prediction_dir: Directory containing the prediction ``.png`` files.
            cat_names: Sequence of class names; its length is used as the
                number of classes when a 2-D mask is split per class.
            max_images: Maximum number of image triples to display, or
                ``None`` for no limit.
        """
        self.original_dir = original_dir
        self.mask_dir = mask_dir
        self.prediction_dir = prediction_dir
        self.cat_names = cat_names
        # An infinite limit keeps the "count >= max_images" check uniform.
        self.max_images = max_images if max_images is not None else float('inf')

    @staticmethod
    def is_image_file(filename):
        """Check if the file is an image based on its extension."""
        # str.endswith accepts a tuple, so one call covers every extension.
        return filename.lower().endswith(ImageVisualizer.VALID_EXTENSIONS)

    def convert_to_one_hot(self, mask, num_classes):
        """Convert a mask of class indices to a one-hot multi-channel array.

        Args:
            mask: Array-like of class indices.
            num_classes: Number of channels to produce; channel ``i`` is 1
                where ``mask == i`` and 0 elsewhere.

        Returns:
            A ``uint8`` array of shape ``mask.shape + (num_classes,)``.
            Values in ``mask`` outside ``range(num_classes)`` produce an
            all-zero vector at that position.
        """
        mask = np.asarray(mask)
        class_ids = np.arange(num_classes)
        # Broadcasting compares every pixel against every class id at once.
        return (mask[..., np.newaxis] == class_ids).astype(np.uint8)

    def visualize_images_channels(self, image, channel_names):
        """Visualize each channel of a multi-channel image in one row.

        Args:
            image: Array whose last axis indexes the channels.
            channel_names: Names used in the panel titles; a channel without
                a matching entry is labelled ``?`` instead of raising.
        """
        num_channels = image.shape[-1]
        plt.figure(figsize=(25, 5))
        for i in range(num_channels):
            name = channel_names[i] if i < len(channel_names) else '?'
            plt.subplot(1, num_channels, i + 1)
            plt.imshow(image[..., i], cmap='gray')
            plt.title(f'{i}: {name}', fontsize=6)
            plt.axis('off')
        plt.show()

    def visualize_images(self, original, mask, prediction, titles):
        """Visualize three images side by side: original, mask, and prediction.

        Args:
            original: Original image in BGR channel order; it is converted
                to RGB for display.
            mask: Mask image, drawn with the ``nipy_spectral`` colormap.
            prediction: Prediction image, drawn with the same colormap.
            titles: Three panel titles, in the same order as the images.
        """
        images = [original, mask, prediction]
        plt.figure(figsize=(15, 5))
        for i, img in enumerate(images):
            # One row with one column per image, so all panels share the width.
            plt.subplot(1, len(images), i + 1)
            if i == 0:
                plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            else:
                plt.imshow(img, cmap='nipy_spectral')
            plt.title(titles[i], fontsize=10)
            plt.axis('off')
        plt.show()

    def _show_class_channels(self, image, label):
        """Plot the per-class channels of one mask-like image.

        An image with more than 3 channels is plotted as-is; a 2-D image is
        treated as class indices and one-hot encoded first. Any other layout
        is reported and skipped.
        """
        if image.ndim == 3 and image.shape[-1] > 3:
            self.visualize_images_channels(image, self.cat_names)
        elif image.ndim == 2:
            one_hot = self.convert_to_one_hot(image, len(self.cat_names))
            self.visualize_images_channels(one_hot, self.cat_names)
        else:
            print(
                f"Skipping channel view for {label}: expected a 2-D class-index "
                f"image or an image with more than 3 channels, got shape {image.shape}."
            )

    def process_images(self, show_channels=False):
        """Display original/mask/prediction triples for images in ``original_dir``.

        Files are visited in sorted name order so that ``max_images`` always
        selects the same subset. A file whose original, mask or prediction
        cannot be loaded is reported and does not count towards the limit.

        Args:
            show_channels: When true, also plot the mask and the prediction
                split into one panel per class.
        """
        count = 0
        for filename in sorted(os.listdir(self.original_dir)):
            if count >= self.max_images:
                break
            if not self.is_image_file(filename):
                continue

            # Mask and prediction share the original's base name, always as .png.
            stem = os.path.splitext(filename)[0]
            original_file_path = os.path.join(self.original_dir, filename)
            mask_file_path = os.path.join(self.mask_dir, stem + '.png')
            prediction_file_path = os.path.join(self.prediction_dir, stem + '.png')

            original_image = cv2.imread(original_file_path)
            # IMREAD_UNCHANGED keeps the stored channel count and bit depth.
            mask_image = cv2.imread(mask_file_path, cv2.IMREAD_UNCHANGED)
            predicted_image = cv2.imread(prediction_file_path, cv2.IMREAD_UNCHANGED)

            if original_image is None or mask_image is None or predicted_image is None:
                print(f"Error loading one or more images for {filename}.")
                continue

            titles = ['Origin:', 'Mask', 'Predict']
            print(prediction_file_path)
            self.visualize_images(original_image, mask_image, predicted_image, titles)

            if show_channels:
                self._show_class_channels(mask_image, 'mask')
                self._show_class_channels(predicted_image, 'prediction')

            count += 1
semantic segmentation 진행하시는 분들께 도움이 되었으면 좋겠습니다.
728x90
반응형