'데이터 엔지니어'로 성장하기

정리하는 걸 좋아하고, 남이 읽으면 더 좋아함

AI/Vision

Vision-AI) coco dataset 시각화 하기 (코드 포함)

MightyTedKim 2024. 5. 8. 03:28
728x90
반응형

coco dataset을 시각화하는 데는 많은 방법이 있습니다.

이번 포스팅에서는 그중 matplotlib으로 직접 시각화하는 방법을 소개하려고 합니다.

 

예상 독자는 아래와 같습니다

1. coco dataset 테스트가 필요하신 분

2. 바로 사용할 수 있는 정리된 class가 필요하신 분

 

목차는 아래와 같습니다.

1. matplotlib 사용

 

pycoco(pycocotools)도 많이 사용하시는데, 저는 불편하더라고요. id 값이 int가 아니면 오류를 뱉는다든가 하는 제약들이 있어서요.

그래서 제가 편하게 사용하려고 아래 코드를 만들었습니다.

import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from collections import Counter

class CocoViewer:
    """Minimal matplotlib-based viewer for a COCO-format dataset.

    The dataset directory is expected to contain an ``annotations``
    sub-directory holding at least one ``.json`` file and an ``images``
    sub-directory holding the files named in the ``images`` section of
    that JSON.

    Args:
        annotation_dir: Root directory of the dataset.
        num_images: Stored on the instance as ``num_images``; not read by
            any method of this class.
        min_points: Exclusive lower bound on the length of a flat
            segmentation list (number of coordinate values, i.e. twice
            the number of vertices).
        max_points: Inclusive upper bound on the same length.
        show_text: When true, draw ``"<category_id>: <name>"`` at the
            mean vertex position of every polygon.

    Raises:
        FileNotFoundError: If the ``annotations`` directory is missing or
            contains no ``.json`` file.
    """

    def __init__(self, annotation_dir, num_images=3, min_points=10, max_points=22504, show_text=True):
        self.annotation_dir = annotation_dir
        self.num_images = num_images
        self.min_points = min_points
        self.max_points = max_points
        self.show_text = show_text
        self.annotation_file_path = self._find_annotation_file(annotation_dir)
        self.annotations = self.load_coco_annotations()
        self.categories = {category['id']: category for category in self.annotations['categories']}
        # Unique category names, in order of first appearance.
        self.category_names = list(dict.fromkeys(category['name'] for category in self.annotations['categories']))
        self.initialize_color_palette()

    @staticmethod
    def _find_annotation_file(annotation_dir):
        """Return the path of the first (alphabetically) ``.json`` file in ``<annotation_dir>/annotations``."""
        annotations_subdir = os.path.join(annotation_dir, "annotations")
        # os.listdir order is arbitrary; sort so the choice is reproducible.
        json_files = sorted(f for f in os.listdir(annotations_subdir) if f.endswith('.json'))
        if not json_files:
            raise FileNotFoundError(f"No .json annotation file found in {annotations_subdir}")
        return os.path.join(annotations_subdir, json_files[0])

    def load_coco_annotations(self):
        """Parse and return the JSON document at ``self.annotation_file_path``."""
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(self.annotation_file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def analyze_annotation(self):
        """Show a horizontal bar chart of annotation counts per category name."""
        id_to_name = {cat['id']: cat['name'] for cat in self.annotations['categories']}
        # Annotations whose category_id is not declared are counted under an
        # explicit "unknown" label instead of raising KeyError.
        category_names = [
            id_to_name.get(annotation['category_id'], f"unknown ({annotation['category_id']})")
            for annotation in self.annotations['annotations']
            if 'category_id' in annotation
        ]
        category_counts = Counter(category_names)

        names = list(category_counts.keys())
        counts = list(category_counts.values())

        plt.figure(figsize=(7, 3))
        bars = plt.barh(names, counts, color='skyblue')
        plt.xlabel('Number of Annotations')
        plt.ylabel('Category Name')
        plt.title('Category Distribution in Annotations')

        # Print the exact count at the end of each bar.
        for bar in bars:
            plt.text(bar.get_width(), bar.get_y() + bar.get_height() / 2,
                     f'{int(bar.get_width())}',
                     va='center', ha='left')

        plt.tight_layout()
        plt.show()

    def initialize_color_palette(self):
        """Reset the per-category color assignment and load the base colormap."""
        cmap = plt.get_cmap('tab20')  # You can change this colormap as needed
        self.colors = cmap(np.linspace(0, 1, cmap.N))  # Extract as many colors as are available in the colormap
        self.color_palette = {}
        self.color_index = 0

    def get_color_for_category(self, category_id):
        """Return the color bound to ``category_id``, assigning the next palette color on first use.

        Colors are handed out in order of first request and wrap around
        when the palette is exhausted.
        """
        if category_id not in self.color_palette:
            self.color_palette[category_id] = self.colors[self.color_index % len(self.colors)]
            self.color_index += 1
        return self.color_palette[category_id]

    def _group_annotations_by_image(self):
        """Return a dict mapping each ``image_id`` to its annotations, in file order."""
        grouped = {}
        for annotation in self.annotations['annotations']:
            grouped.setdefault(annotation.get('image_id'), []).append(annotation)
        return grouped

    def _polygons_in_range(self, annotation):
        """Return the polygon segmentations of ``annotation`` as ``(k, 2)`` arrays.

        Only flat coordinate lists whose length ``n`` satisfies
        ``min_points < n <= max_points`` are returned. Non-list
        segmentations (e.g. RLE dicts) and odd-length lists, which cannot
        form (x, y) pairs, are skipped.
        """
        segmentation = annotation.get('segmentation')
        if not isinstance(segmentation, list):
            return []
        polygons = []
        for seg in segmentation:
            if not isinstance(seg, (list, tuple)):
                continue
            if not (self.min_points < len(seg) <= self.max_points):
                continue
            if len(seg) % 2:
                continue
            polygons.append(np.array(seg).reshape((-1, 2)))
        return polygons

    def visualize_image(self, n=1):
        """Visualize the first n images with annotations."""
        image_dir = os.path.join(self.annotation_dir, 'images')
        # Group once so each image does not rescan the full annotation list.
        annotations_by_image = self._group_annotations_by_image()

        for index, image_info in enumerate(self.annotations['images']):
            if index >= n:
                break
            image_path = os.path.join(image_dir, image_info['file_name'])
            self.draw_segmentations(
                image_path,
                image_info['id'],
                annotations=annotations_by_image.get(image_info['id'], []),
            )

    def draw_segmentations(self, image_path, image_id, annotations=None):
        """Display one image with its polygon segmentations filled in.

        Args:
            image_path: Path of the image file to load with OpenCV.
            image_id: COCO image id used to select annotations.
            annotations: Optional pre-selected annotations for this image.
                When omitted, the full annotation list is filtered by
                ``image_id``.
        """
        print(f"Processing image: {image_path}")
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to load image {image_path}")
            return

        if annotations is None:
            annotations = [a for a in self.annotations['annotations'] if a.get('image_id') == image_id]

        # OpenCV loads BGR; matplotlib expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(image)

        for annotation in annotations:
            category_id = annotation.get('category_id')
            category = self.categories.get(category_id)
            if category is None:
                # Skip only this annotation; the rest of the image is still drawn.
                print(f"Skipping annotation with unknown category ID: {category_id}")
                continue
            category_name = category['name']
            color = self.get_color_for_category(category_id)
            for seg_np in self._polygons_in_range(annotation):
                poly = Polygon(seg_np, edgecolor='none', facecolor=color, fill=True, lw=2, alpha=1.0)
                ax.add_patch(poly)
                if self.show_text:
                    # Label at the mean of the vertices.
                    rx, ry = seg_np.mean(axis=0)
                    ax.text(rx, ry, f'{category_id}: {category_name}', color='white', weight='bold', fontsize=12, ha='center', va='center')

        ax.axis('off')
        plt.show()

 

실행하면 아래처럼 나옵니다.

text를 보여주는 경우에는, min/max_points를 적절히 조절해주는 것이 필요합니다.

 

적절한 min/max_points 값을 정하려면, segmentation 좌표 개수의 분포도를 먼저 확인해 보는 것이 좋습니다.

 

 

 

https://fictitious.tistory.com/2

728x90
반응형