import torch
import torch.nn as nn
# Channel Attention
class ChannelAttention(nn.Module) ## kế thừa PyTorch
def __init__(self, soKenh, tile_soKenhGiam=16):
super(ChannelAttention, self).__init__()
# Gộp Max pooling và Trung Bình pooling về 1x1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
#Tạo mạng tích chập
self.mlp = nn.Sequential(
nn.Linear(in_features=soKenh, out_features=soKenh // tile_soKenhGiam, bias=False),
nn.ReLU(),
nn.Linear(in_features=soKenh // tile_soKenhGiam, out_features=soKenh, bias=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.mlp(self.avg_pool(x)) #lấy tb pooling -> qua 2 lớp FC
max_out = self.mlp(self.max_pool(x)) #lấy max pooling -> qua 2 lớp FC
# Cộng hai kết quả lại và đi qua hàm Sigmoid
out = avg_out + max_out
return self.sigmoid(out)
# Class Spatial Attention
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'Kernel size phải là 3 hoặc 7'
padding = 3 if kernel_size == 7 else 1 # thêm padding
# Tích chập để gộp thông tin không gian
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# Lấy giá trị trung bình và lớn nhất dọc theo trục channel
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
# Nối (concatenate) hai ma trận lại với nhau
x_cat = torch.cat([avg_out, max_out], dim=1)
# Đi qua lớp tích chập và hàm Sigmoid
out = self.conv1(x_cat)
return self.sigmoid(out)
class CBAM(nn.Module):
def __init__(self, soKenh, tile_soKenhGiam=16, kernel_size=7):
super(CBAM, self).__init__()
self.channel_attention = ChannelAttention(soKenh, tile_soKenhGiam)
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
# Nhân input với trọng số của Channel Attention
x = x * self.channel_attention(x)
# Nhân kết quả tiếp với trọng số của Spatial Attention
x = x * self.spatial_attention(x)
return x
class CNN_CBAM(nn.Module):
def __init__(self, num_classes=10):
super(CNN_CBAM, self).__init__()
# Lớp Convolution đầu tiên
self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
self.relu1 = nn.ReLU()
# Tích hợp CBAM ngay sau khi trích xuất đặc trưng
self.cbam = CBAM(in_planes=64)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Classifier
# Giả sử ảnh đầu vào là 32x32. Sau MaxPool 2x2 thì kích thước còn 16x16
self.fc = nn.Linear(64 * 16 * 16, num_classes)
# luồng hoạt động CNN - CBAM
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
# Đi qua khối Attention
x = self.cbam(x)
x = self.pool(x)
# Làm phẳng tensor để đưa vào lớp Linear
x = torch.flatten(x, 1)
x = self.fc(x)
return x
Đăng nhận xét