오늘은 객체 탐지를 위해 Transformer를 활용하는 모델인 DETR (End-to-End Object Detection with Transformers, DEtection TRansformer)모델을 알아보고 실습해보려고 합니다.
DETR은 기존에 자연어처리 분야에서 많이 쓰이던 Transformer가 객체 탐지에도 활용될 수 있다는 가능성을 열어준 모델로, self-attention을 통해 이미지 내 객체 간 관계를 효과적으로 학습하고, 복잡한 후처리 과정 없이 한 번에 객체 탐지와 분류를 수행할 수 있는 end-to-end 모델입니다.
from torch import nn
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim=256, nheads=8,
num_encoder_layers=6, num_decoder_layers=6, num_queries=100):
super().__init__()
# (1) Create ResNet-50 backbone
self.backbone = resnet50()
del self.backbone.fc
# (2) Create conversion layer
self.conv = nn.Conv2d(2048, hidden_dim, 1)
# (3) Create a default PyTorch transformer
self.transformer = nn.Transformer(
hidden_dim, nheads, num_encoder_layers, num_decoder_layers) # num_encoder_layers, num_decoder_layers
# (4) Prediction heads, one extra class for predicting non-empty slots
# note that in baseline DETR linear_bbox layer is 3-layer MLP
self.linear_class = nn.Linear(hidden_dim, num_classes + 1) # num_classes
self.linear_bbox = nn.Linear(hidden_dim, 4)
# (5) Output positional encodings (object queries)
self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim)) # num_queries, hidden_dim
# spatial positional encodings
# note that in baseline DETR we use sine positional encodings
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
# (5) Forward Pass
def forward(self, inputs):
# propagate inputs through ResNet-50 up to avg-pool layer
x = self.backbone.conv1(inputs)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)
# convert from 2048 to 256 feature planes for the transformer
h = self.conv(x)
# construct positional encodings
H, W = h.shape[-2:]
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
# propagate through the transformer
h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1)).transpose(0, 1)
# finally project transformer outputs to class labels and bounding boxes
pred_logits = self.linear_class(h) # self.linear_class(h)
pred_boxes = self.linear_bbox(h).sigmoid() # self.linear_bbox(h).sigmoid()
return {'pred_logits': pred_logits,
'pred_boxes': pred_boxes}
공식 문서를 토대로 Pytorch를 활용하여 DETR을 구현한 코드입니다. 코드 내 주석으로 넘버링을 해두었고, 아래에 좀 더 자세한 코드 설명을 적어보았습니다.
(1) Create ResNet-50 backbone
ResNet-50을 사용하여 이미지를 Feature map으로 변환합니다. self.backbone은 ResNet-50을 사용하되, 마지막 fc layer는 제거하여 추출한 feature만을 출력합니다.
(2) Create conversion layer
Conv2d를 사용하여 ResNet의 출력 채널 수(2048) 를 트랜스포머 입력 크기(hidden_dim=256)으로 줄입니다.
(3) Create a default PyTorch transformer
self.transformer는 PyTorch의 트랜스포머 모듈을 사용하며, 인코더와 디코더 층을 각각 num_encoder_layers, num_decoder_layers로 지정합니다.
트랜스포머는 이미지의 feature map을 받아 객체 탐지를 위한 전역적인 정보를 학습합니다.
(4) Prediction heads, one extra class
linear_class: hidden_dim 차원의 출력으로 클래스를 예측하기 위한 선형 레이어입니다. 'No object'를 예측할 수 있도록 클래스 개수보다 하나 더 많은 출력을 생성합니다.
linear_bbox: hidden_dim에서 4차원 벡터로 출력해 Bounding Box를 예측합니다. 여기서 (x, y, w, h) 는 이미지 크기를 0에서 1로 정규화한 상대적인 위치과 크기입니다.
(5) Output positional encodings
query_pos: 객체 탐지 쿼리로, 트랜스포머 디코더의 입력입니다.
row_embed, col_emed: 공간 위치에 대한 positional encoding을 제공합니다.
(6) Forward Pass
ResNet-50의 피쳐 맵을 추출하여 트랜스포머 입력 크기로 변환
공간 위치 정보를 포함한 pos를 생성해 피쳐 맵에 더한 후, 트랜스포머에 입력으로 전달
트랜스포머의 출력을 linear_class와 linear_bbox에 입력하여 class와 bounding box를 예측.
다음 실습에서는 직접 제가 찍은 이미지로 객체 탐지를 해보겠습니다.
참조:
https://github.com/facebookresearch/detr
GitHub - facebookresearch/detr: End-to-End Object Detection with Transformers
End-to-End Object Detection with Transformers. Contribute to facebookresearch/detr development by creating an account on GitHub.
github.com