Simple Object Detection with DETR
Detection Transformer (DETR) 을 이용해서 간단한 Object Detection을 수행해보자.
본 코드는 "10가지 프로젝트로 끝내는 트랜스포머 활용 가이드 with 파이토치", 루비페이퍼, 프렘 팀시나 지음, 임선집 옮김 채호창 감수에 있는 코드를 그대로 가져와서 실행만 해 본 결과이다.
먼저 하고자 하는 것은 아래 사진에서 사람, 차량 등을 검출해내는 것이다.
결과적으로 동작하는 코드는 아래와 같다.
import torch
import torchvision.transforms as T
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from transformers import DetrImageProcessor, DetrForObjectDetection, DetrConfig
#########################
# 1. get image
img_path = r'D:\example\free-photo-of-people-crossing-the-road.jpeg'
img = Image.open(img_path)
img = img.convert('RGB')
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img_tensor = transform(img).unsqueeze(0)
########################
# 2. get model
config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
model.eval()
########################
# 3. predict
with torch.no_grad():
outputs = model(img_tensor)
target_sizes = torch.tensor([img.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
########################
# 4. visualize
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(img)
colors = plt.get_cmap("tab20").colors
# results["scores"] 는 예측
# results["labels"] 는 레이블
# results["boxes"] 는 객체의 경계 상자
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
x,y,w,h = box
w = w - x
h = h - y
rect = plt.Rectangle((x,y), w, h, linewidth=1, edgecolor=colors[label%20], facecolor='none')
ax.add_patch(rect)
ax.text(
x, y,
f"{model.config.id2label[label.item()]}{round(score.item(), 2)}",
fontsize=15,
color=colors[label%20]
)
plt.show()
일단 코드는 크게 4 부분으로 나누어진다.
1. 이미지를 읽는다.
2. 이미 학습된 모델을 가져온다.
3. predict를 수행한다.
4. 결과를 화면에 보인다.
세부적으로 각 단계를 보자.
1. 이미지를 읽는다.
먼저 이미지를 읽는 부분을 보면 아래와 같은 코드가 있다.
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img_tensor = transform(img).unsqueeze(0)
이 코드는 아래와 같은 역할을 한다.
1) transform = T.Compose([...])
- Compose는 여러 개의 변환(transform)을 순차적으로 묶어서 실행할 때 사용하는 함수
- 여기선 세 가지 전처리를 순서대로 적용함
- T.Resize(800)
- 이미지를 한쪽 변(짧은 쪽 기준)을 800픽셀로 리사이즈함
- 비율 유지하면서 리사이즈 됨
- T.ToTensor()
- PIL 이미지나 numpy array를 파이토치 텐서로 변환.
- 동시에 픽셀값을 0~255 를 0.0 ~ 1.0 범위로 정규화해줌.
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- 텐서 형태의 이미지 데이터를 평균과 표준편차로 정규화
- 여기의 값은 ImageNet 데이터셋에서 학습할 때 쓴 평균/표준편차 값임
- 보통 ImageNet pretrained 모델을 쓸 때 이 값으로 normalize해줘야 모델이 잘 동작함
2) img_tensor = transform(img).unsqueeze(0)
- transform(img) : 위에서 만든 transform 파이프라인을 이미지 img에 적용.
- .unsqueeze(0) :
- 텐서의 **0번째 차원(batch dimension)**을 추가.
- 원래 (C, H, W) → (1, C, H, W)
- 모델에 넣을 때 batch 형태로 만들어주는 과정.
즉, 이미지를 적절한 크기로 resize하고 tensor로 변환하고 normalize 해 준 뒤, unsqueeze를 통해서 batch 형태로 넣을 수 있게 변환해준 것이다. 물론 현재 예제에서는 batch 를 적용하고 있지는 않다.
2. 학습된 모델을 가져온다.
아래 코드에서 이미 학습된 모델을 그대로 가져온다.
########################
# 2. get model
config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
model.eval()
여기서는 facebook/detr-resnet-50 모델을 이미 pretrained 된 것을 가져와서 바로 사용하는 것으로 되어 있다.
여기서 config 는 뭔고 하니... 지피티에게 물어보니 아래와 같이 알려준다.
📌 config란?
config는 **모델의 설정 값들(hyperparameter나 설정 정보)**을 담고 있는 객체예요.
예를 들면:
- hidden layer 개수
- attention head 개수
- position embedding 방식
- class 개수
- dropout 비율
- 기타 모델 동작에 필요한 설정 값들
모델 아키텍처와 관련된 정보가 저장돼 있어, 모델을 불러오거나 새로 만들 때 참고하는 설정 파일이라고 보면 돼요.
흠. 그럼 대충 모델의 정보를 가지고 있는 설정객체 정도로 이해하면 될 것 같다.
3. predict를 수행한다.
이미지를 읽어왔고, 모델도 가져왔으니 predict를 수행해보자.
좀 더 풀어쓰면 이미지 내에서 객체를 찾아보자.
with torch.no_grad():
outputs = model(img_tensor)
target_sizes = torch.tensor([img.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
outputs = model(img_tensor) 에서 이미지를 모델의 입력으로 넣어서 outputs를 가져오는 부분을 볼 수 있다.
4. 결과를 화면에 보인다.
마지막 코드 부분은 단순히 입력 이미지에 객체 검출을 수행한 결과를 덧그리는 작업이다.
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(img)
colors = plt.get_cmap("tab20").colors
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
x,y,w,h = box
w = w - x
h = h - y
rect = plt.Rectangle((x,y), w, h, linewidth=1, edgecolor=colors[label%20], facecolor='none')
ax.add_patch(rect)
ax.text(
x, y,
f"{model.config.id2label[label.item()]}{round(score.item(), 2)}",
fontsize=15,
color=colors[label%20]
)
plt.show()
이미지 레이블별로 색깔을 다르게 해주기 위해서 colors를 가져오고, results 의 label 에 따라 색깔을 다르게 칠해주는 코드를 볼 수 있다.
최종적으로 실행한 결과는 아래와 같다. 사람과 차량, handbag, traffic light 등이 검출된 것을 볼 수 있다.
코드 몇줄 안되지만, 이미지에서 사람, 차량, 핸드백 등을 잘 검출한 것을 볼 수 있다.
이미 있는 모델을 잘 활용하면 개발 속도가 훨씬 빨라질 듯 하다.
그럼 이만.
코드 출처: 10가지 프로젝트로 끝내는 트랜스포머 활용 가이드 with 파이토치, 루비페이퍼, 프렘 팀시나 지음.