AI/Vision
Vision-AI) coco dataset 시각화 하기 (코드 포함)
MightyTedKim
2024. 5. 8. 03:28
728x90
반응형
coco dataset을 시각화하는 데는 많은 방법이 있습니다.
이번 포스팅에서는 그중 matplotlib을 사용하는 방법을 소개하려고 합니다.
예상 독자는 아래와 같습니다
1. coco dataset 테스트가 필요하신 분
2. 바로 사용할 수 있는 정리된 class가 필요하신 분
목차는 아래와 같습니다.
1. matplotlib 사용
pycocotools(pycoco)도 많이 사용하시는데, 저는 불편하더라고요. id 값이 int가 아니면 오류가 나는 등 여러 제약이 있어서요.
그래서 제가 편하게 사용하려고 아래 코드를 만들었습니다.
import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from collections import Counter
class CocoViewer:
    """Visualize a COCO-format dataset with matplotlib (no pycocotools).

    The dataset is expected to be laid out as::

        annotation_dir/
            annotations/<name>.json   # COCO annotation file
            images/<file_name>        # files named by the JSON's images[].file_name

    Args:
        annotation_dir: Dataset root directory (see layout above).
        num_images: Stored on the instance; not read by any method of this
            class (``visualize_image`` takes its own ``n``).
        min_points: Exclusive lower bound on ``len(seg)`` for a polygon to be
            drawn. Note that ``len(seg)`` is the number of *coordinates* in
            the flat ``[x1, y1, x2, y2, ...]`` list, i.e. twice the number of
            vertices.
        max_points: Inclusive upper bound on ``len(seg)``.
        show_text: When true, label each drawn polygon with
            ``"<category_id>: <category_name>"`` at its vertex centroid.

    Raises:
        FileNotFoundError: If ``annotation_dir/annotations`` does not exist.
        IndexError: If that directory contains no ``.json`` file.
    """

    def __init__(self, annotation_dir, num_images=3, min_points=10, max_points=22504, show_text=True):
        self.annotation_dir = annotation_dir
        self.num_images = num_images
        self.min_points = min_points
        self.max_points = max_points
        self.annotation_file_path = self._find_annotation_file()
        self.annotations = self.load_coco_annotations()
        # category id -> full category dict, for O(1) lookup while drawing.
        self.categories = {category['id']: category for category in self.annotations['categories']}
        # Unique category names in first-seen order.
        self.category_names = list(dict.fromkeys(
            category['name'] for category in self.annotations['categories']
        ))
        self.initialize_color_palette()
        self.show_text = show_text

    def _find_annotation_file(self):
        """Return the path of the alphabetically first .json under annotations/."""
        annotations_dir = os.path.join(self.annotation_dir, "annotations")
        # os.listdir order is arbitrary; sort so the choice is reproducible
        # when several JSON files are present.
        json_files = sorted(f for f in os.listdir(annotations_dir) if f.endswith('.json'))
        if not json_files:
            # IndexError is kept for backward compatibility: the previous
            # implementation indexed an empty list here.
            raise IndexError(f"No .json annotation file found in {annotations_dir}")
        return os.path.join(annotations_dir, json_files[0])

    def load_coco_annotations(self):
        """Load and return the parsed annotation JSON as a dict."""
        # JSON is UTF-8 by specification; do not depend on the locale
        # encoding (category names may be non-ASCII).
        with open(self.annotation_file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def analyze_annotation(self):
        """Show a horizontal bar chart of annotation counts per category name."""
        id_to_name = {cat['id']: cat['name'] for cat in self.annotations['categories']}
        # Annotations without a category_id are ignored; ids that are not
        # declared in 'categories' are counted under an explicit label
        # instead of raising KeyError.
        category_counts = Counter(
            id_to_name.get(annotation['category_id'], f"unknown ({annotation['category_id']})")
            for annotation in self.annotations['annotations']
            if 'category_id' in annotation
        )
        names = list(category_counts.keys())
        counts = list(category_counts.values())

        fig, ax = plt.subplots(figsize=(7, 3))
        bars = ax.barh(names, counts, color='skyblue')
        ax.set_xlabel('Number of Annotations')
        ax.set_ylabel('Category Name')
        ax.set_title('Category Distribution in Annotations')
        # Print the exact count at the end of every bar.
        for bar in bars:
            ax.text(bar.get_width(), bar.get_y() + bar.get_height() / 2,
                    f'{int(bar.get_width())}',
                    va='center', ha='left')
        fig.tight_layout()
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def initialize_color_palette(self):
        """Reset the color table and the category -> color assignment."""
        cmap = plt.get_cmap('tab20')  # You can change this colormap as needed
        # One entry per color available in the colormap.
        self.colors = cmap(np.linspace(0, 1, cmap.N))
        self.color_palette = {}
        self.color_index = 0

    def get_color_for_category(self, category_id):
        """Return the color for ``category_id``, assigning one on first use.

        Colors are handed out in palette order and wrap around once every
        palette entry has been used, so the same id always maps to the same
        color for the lifetime of the palette.
        """
        if category_id not in self.color_palette:
            self.color_palette[category_id] = self.colors[self.color_index % len(self.colors)]
            self.color_index += 1
        return self.color_palette[category_id]

    def visualize_image(self, n=1):
        """Visualize the first n images with annotations."""
        image_dir = os.path.join(self.annotation_dir, 'images')
        for index, image_info in enumerate(self.annotations['images']):
            if index >= n:
                break
            image_path = os.path.join(image_dir, image_info['file_name'])
            self.draw_segmentations(image_path, image_info['id'])

    def _iter_polygons(self, annotation):
        """Yield an (N, 2) vertex array for each drawable polygon of ``annotation``.

        Only list-style polygon segmentations are handled. RLE masks (a dict),
        missing segmentations, entries with an odd number of coordinates and
        entries outside ``min_points < len(seg) <= max_points`` are skipped.
        """
        segmentation = annotation.get('segmentation')
        if not isinstance(segmentation, (list, tuple)):
            return
        for seg in segmentation:
            if not isinstance(seg, (list, tuple)):
                continue
            if not (self.min_points < len(seg) <= self.max_points):
                continue
            coords = np.array(seg)
            if coords.size % 2:
                # Cannot be paired into (x, y) vertices.
                continue
            yield coords.reshape((-1, 2))

    def draw_segmentations(self, image_path, image_id):
        """Show one image with the polygons of all annotations for ``image_id``.

        Args:
            image_path: Path of the image file to load.
            image_id: Value matched against each annotation's ``image_id``.

        Prints a message and returns without drawing if the image cannot be
        loaded.
        """
        print(f"Processing image: {image_path}")
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to load image {image_path}")
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(image)

        for annotation in self.annotations['annotations']:
            if annotation.get('image_id') != image_id:
                continue
            category_id = annotation.get('category_id')
            category = self.categories.get(category_id)
            if category is None:
                # Skip only this annotation; the rest of the image is still drawn.
                print(f"Skipping annotation with unknown category ID: {category_id}")
                continue
            category_name = category['name']
            color = self.get_color_for_category(category_id)
            for seg_np in self._iter_polygons(annotation):
                poly = Polygon(seg_np, edgecolor='none', facecolor=color, fill=True, lw=2, alpha=1.0)
                ax.add_patch(poly)
                if self.show_text:
                    # Label at the mean of the vertices.
                    rx, ry = seg_np.mean(axis=0)
                    ax.text(rx, ry, f'{category_id}: {category_name}', color='white',
                            weight='bold', fontsize=12, ha='center', va='center')

        ax.axis('off')
        plt.show()
        # Release the figure so looping over many images does not leak figures.
        plt.close(fig)
실행하면 아래처럼 나옵니다.
text를 보여주는 경우에는, min/max_points를 적절히 조절해주는 것이 필요합니다.
적절한 min/max_points 값을 정하려면, segmentation별 좌표 개수의 분포를 먼저 확인해 보는 것이 좋습니다.
728x90
반응형