본문 바로가기
인공지능/Deep Learning

[Deep Learning] DETR 모델 이해하고 실습하기 (1)

by 이준언 2024. 11. 13.

오늘은 객체 탐지를 위해 Transformer를 활용하는 모델인 DETR (End-to-End Object Detection with Transformers, DEtection TRansformer)모델을 알아보고 실습해보려고 합니다. 

DETR은 기존에 자연어처리 분야에서 많이 쓰이던 Transformer가 객체 탐지에도 활용될 수 있다는 가능성을 열어준 모델로, self-attention을 통해 이미지 내 객체 간 관계를 효과적으로 학습하고, 복잡한 후처리 과정 없이 한 번에 객체 탐지와 분류를 수행할 수 있는 end-to-end 모델입니다. 

from torch import nn
class DETR(nn.Module):
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6, num_queries=100):
        super().__init__()
      
        # (1) Create ResNet-50 backbone
        self.backbone = resnet50()
        del self.backbone.fc

        # (2) Create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # (3) Create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers) # num_encoder_layers, num_decoder_layers

        # (4) Prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1) # num_classes
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # (5) Output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim)) # num_queries, hidden_dim

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

        # (5) Forward Pass
    def forward(self, inputs):
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # propagate through the transformer
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)



        # finally project transformer outputs to class labels and bounding boxes
        pred_logits = self.linear_class(h) # self.linear_class(h)
        pred_boxes = self.linear_bbox(h).sigmoid() # self.linear_bbox(h).sigmoid()

        return {'pred_logits': pred_logits,
                'pred_boxes': pred_boxes}

공식 문서를 토대로 Pytorch를 활용하여 DETR을 구현한 코드입니다. 코드 내 주석으로 넘버링을 해두었고, 아래에 좀 더 자세한 코드 설명을 적어보았습니다. 

(1) Create ResNet-50 backbone

ResNet-50을 사용하여 이미지를 Feature map으로 변환합니다. self.backbone은 ResNet-50을 사용하되, 마지막 fc layer는 제거하여 추출한 feature만을 출력합니다.

(2) Create conversion layer

Conv2d를 사용하여 ResNet의 출력 채널 수(2048) 를 트랜스포머 입력 크기(hidden_dim=256)으로 줄입니다.

(3) Create a default PyTorch transformer

self.transformer는 PyTorch의 트랜스포머 모듈을 사용하며, 인코더와 디코더 층을 각각 num_encoder_layers, num_decoder_layers로 지정합니다. 

트랜스포머는 이미지의 feature map을 받아 객체 탐지를 위한 전역적인 정보를 학습합니다. 

(4) Prediction heads, one extra class

linear_class: hidden_dim 차원의 출력으로 클래스를 예측하기 위한 선형 레이어입니다. 'No object'를 예측할 수 있도록 클래스 개수보다 하나 더 많은 출력을 생성합니다. 

linear_bbox: hidden_dim에서 4차원 벡터로 출력해 Bounding Box를 예측합니다. 여기서 (x, y, w, h) 는 이미지 크기를 0에서 1로 정규화한 상대적인 위치과 크기입니다. 

(5) Output positional encodings

query_pos: 객체 탐지 쿼리로, 트랜스포머 디코더의 입력입니다.

row_embed, col_emed: 공간 위치에 대한 positional encoding을 제공합니다.

(6) Forward Pass

ResNet-50의 피쳐 맵을 추출하여 트랜스포머 입력 크기로 변환

공간 위치 정보를 포함한 pos를 생성해 피쳐 맵에 더한 후, 트랜스포머에 입력으로 전달

트랜스포머의 출력을 linear_class linear_bbox 입력하여 class bounding box 예측.


다음 실습에서는 직접 제가 찍은 이미지로 객체 탐지를 해보겠습니다. 

 

참조:

https://github.com/facebookresearch/detr

 

GitHub - facebookresearch/detr: End-to-End Object Detection with Transformers

End-to-End Object Detection with Transformers. Contribute to facebookresearch/detr development by creating an account on GitHub.

github.com