Mô hình CNN kết hợp với cơ chế chú ý CBAM

 Mô hình CNN kết hợp với cơ chế chú ý CBAM



import torch
import torch.nn as nn
# Channel Attention
class ChannelAttention(nn.Module) ## kế thừa PyTorch
def __init__(self, soKenh, tile_soKenhGiam=16):
    super(ChannelAttention, self).__init__()

    # Gộp Max pooling và Trung Bình pooling về 1x1
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.max_pool = nn.AdaptiveMaxPool2d(1)

    #Tạo mạng tích chập
    self.mlp = nn.Sequential(
            nn.Linear(in_features=soKenh, out_features=soKenh // tile_soKenhGiam, bias=False),
            nn.ReLU(),
            nn.Linear(in_features=soKenh // tile_soKenhGiam, out_features=soKenh, bias=False)
    )
    self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x)) #lấy tb pooling -> qua 2 lớp FC
        max_out = self.mlp(self.max_pool(x)) #lấy max pooling -> qua 2 lớp FC
        # Cộng hai kết quả lại và đi qua hàm Sigmoid
        out = avg_out + max_out
        return self.sigmoid(out)
   
# Class Spatial Attention
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'Kernel size phải là 3 hoặc 7'
        padding = 3 if kernel_size == 7 else 1 # thêm padding
       
        # Tích chập để gộp thông tin không gian
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Lấy giá trị trung bình và lớn nhất dọc theo trục channel
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
       
        # Nối (concatenate) hai ma trận lại với nhau
        x_cat = torch.cat([avg_out, max_out], dim=1)
       
        # Đi qua lớp tích chập và hàm Sigmoid
        out = self.conv1(x_cat)
        return self.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, soKenh, tile_soKenhGiam=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(soKenh, tile_soKenhGiam)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # Nhân input với trọng số của Channel Attention
        x = x * self.channel_attention(x)
        # Nhân kết quả tiếp với trọng số của Spatial Attention
        x = x * self.spatial_attention(x)
        return x

class CNN_CBAM(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN_CBAM, self).__init__()
        # Lớp Convolution đầu tiên
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
       
        # Tích hợp CBAM ngay sau khi trích xuất đặc trưng
        self.cbam = CBAM(in_planes=64)
       
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
       
        # Classifier
        # Giả sử ảnh đầu vào là 32x32. Sau MaxPool 2x2 thì kích thước còn 16x16
        self.fc = nn.Linear(64 * 16 * 16, num_classes)

    # luồng hoạt động CNN - CBAM
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
       
        # Đi qua khối Attention
        x = self.cbam(x)
       
        x = self.pool(x)
       
        # Làm phẳng tensor để đưa vào lớp Linear
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

Demo này gắn khối CBAM sau lớp Convu để khắc phục nhược điểm cục bộ của lớp convu

Đăng nhận xét

Post a Comment (0)

Mới hơn Cũ hơn