AI/Vision
Vision-AI) Mask 정리해서 시각화하기(one-hot, show_channels)
MightyTedKim
2024. 5. 8. 03:46
728x90
반응형
semantic segmentation에서
mask가 제대로 만들어졌는지 궁금할 때가 있어요.
이걸 위해서 MaskViewer라는 class를 만들었어요.
원본 이미지와 mask 이미지만 보려고 할 때는 아래처럼 show_channels=False를 입력하면 됩니다.
채널별로 보고 싶으면 show_channels=True를 입력하면 돼요.
코드는 아래와 같습니다. (os, cv2, numpy(np), matplotlib.pyplot(plt)를 미리 import해 두어야 합니다.)
class MaskViewer:
    """Display images side by side with their segmentation masks.

    For every image in ``input_img_dir``, the mask is looked up in
    ``input_mask_dir`` under the same base name with a ``.png`` extension.
    Optionally, each class channel of the one-hot encoded mask is shown as a
    separate grayscale panel.

    Relies on the module-level names ``os``, ``cv2`` (OpenCV), ``np`` (NumPy)
    and ``plt`` (``matplotlib.pyplot``); the surrounding script must import them.
    """

    # Lower-case file extensions treated as images when scanning input_img_dir.
    VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, input_img_dir, input_mask_dir, cat_names):
        """
        Args:
            input_img_dir: Directory containing the original images.
            input_mask_dir: Directory containing the ``.png`` masks.
            cat_names: Class names; channel ``i`` of the one-hot mask marks
                pixels whose mask value equals ``i`` and is titled ``cat_names[i]``.
        """
        self.input_img_dir = input_img_dir
        self.input_mask_dir = input_mask_dir
        self.cat_names = cat_names

    def is_image_file(self, filename):
        """Return True if *filename* ends with a supported image extension (case-insensitive)."""
        return filename.lower().endswith(self.VALID_EXTENSIONS)

    @staticmethod
    def convert_to_one_hot(mask, num_classes):
        """Convert a class-index mask to a one-hot encoded uint8 array.

        Returns an array of shape ``mask.shape + (num_classes,)`` where channel
        ``i`` is 1 exactly where ``mask == i`` and 0 elsewhere. Pixel values
        outside ``[0, num_classes)`` produce all-zero vectors.
        """
        mask = np.asarray(mask)
        # Broadcast (..., 1) against the class ids (num_classes,) in one vectorized compare.
        return (mask[..., np.newaxis] == np.arange(num_classes)).astype(np.uint8)

    def visualize_images(self, original, mask, num_classes=None):
        """Show the original image next to the raw mask.

        Args:
            original: BGR image as returned by ``cv2.imread``; converted to RGB for display.
            mask: Mask as loaded from disk (presumably per-pixel class indices).
            num_classes: If given (and non-zero), pin the colormap range to
                ``[0, num_classes - 1]`` so a class keeps the same color in every
                image. If None, matplotlib scales colors to each mask's own min/max.
        """
        if num_classes:
            color_range = {"vmin": 0, "vmax": max(num_classes - 1, 1)}
        else:
            color_range = {}
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
        plt.title('Original Image')
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(mask, cmap='nipy_spectral', **color_range)
        # This panel shows the raw index mask; the one-hot view is visualize_images_channels.
        plt.title('Mask Image (class indices)')
        plt.axis('off')
        plt.show()

    def visualize_images_channels(self, image, channel_names):
        """Plot each channel of a multi-channel image as its own grayscale panel.

        Args:
            image: Array whose last axis indexes channels.
            channel_names: Titles for the panels; must have at least
                ``image.shape[-1]`` entries.
        """
        num_channels = image.shape[-1]
        plt.figure(figsize=(10, 3))
        for i in range(num_channels):
            plt.subplot(1, num_channels, i + 1)
            plt.imshow(image[..., i], cmap='gray')
            plt.title(f'{i}: {channel_names[i]}', fontsize=6)
            plt.axis('off')
        plt.tight_layout()
        plt.show()

    def visualize_image_set(self, filename, show_channels=False):
        """Load one image/mask pair by filename and visualize it.

        Prints the mask path, then shows the original and the mask. When
        *show_channels* is true, it also shows one panel per class of the
        one-hot mask. If either file cannot be read, it prints an error and
        returns without plotting.
        """
        stem = os.path.splitext(filename)[0]
        original_image_path = os.path.join(self.input_img_dir, filename)
        mask_file_path = os.path.join(self.input_mask_dir, stem + '.png')
        print(f"{mask_file_path}")
        original_image = cv2.imread(original_image_path)
        # IMREAD_UNCHANGED keeps the stored pixel values instead of converting to 3-channel BGR.
        mask_image = cv2.imread(mask_file_path, cv2.IMREAD_UNCHANGED)
        if original_image is None or mask_image is None:
            # Best-effort: report and skip this pair rather than abort a batch.
            print(f"Error loading images for {filename}.")
            return
        num_classes = len(self.cat_names)
        self.visualize_images(original_image, mask_image, num_classes=num_classes)
        if show_channels:
            mask_one_hot = self.convert_to_one_hot(mask_image, num_classes)
            self.visualize_images_channels(mask_one_hot, self.cat_names)

    def process_and_visualize_n_masks(self, n, show_channels=False):
        """Visualize up to *n* image/mask pairs from ``input_img_dir``.

        Files are taken in sorted name order (``os.listdir`` order is
        arbitrary), so repeated calls show the same images. ``n <= 0`` shows nothing.
        """
        filenames = sorted(f for f in os.listdir(self.input_img_dir) if self.is_image_file(f))
        for filename in filenames[:max(n, 0)]:
            self.visualize_image_set(filename, show_channels)
mask를 이용하시는 분들께 도움이 되었으면 좋겠습니다.
728x90
반응형