CBAM 代码
CBAM(Convolutional Block Attention Module,卷积块注意力模块)是一种用于增强卷积神经网络性能的注意力机制:它先沿通道维度、再沿空间维度依次计算注意力图,并将注意力图与特征图逐元素相乘,从而细化特征。
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChannelAttention(nn.Module):
    """Channel attention sub-module of CBAM.

    Squeezes the spatial dimensions with both global average pooling and
    global max pooling, feeds each pooled descriptor through one shared
    bottleneck MLP (built from 1x1 convolutions), sums the two results and
    applies a sigmoid. The result is a per-channel weight map that the caller
    multiplies with the input feature map.

    Args:
        in_channels: Number of channels C of the input feature map.
        ratio: Reduction ratio of the bottleneck MLP. The hidden width is
            ``in_channels // ratio``, clamped to at least 1 so that inputs
            with fewer than ``ratio`` channels still get a usable MLP.

    Shape:
        Input ``(N, C, H, W)`` -> output ``(N, C, 1, 1)`` with values in (0, 1).
    """

    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Without the clamp, in_channels < ratio gives a hidden width of 0,
        # i.e. a degenerate zero-width bottleneck that cannot learn anything.
        hidden_channels = max(1, in_channels // ratio)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return channel attention weights of shape ``(N, C, 1, 1)`` for ``x``."""
        # Both pooled descriptors go through the same MLP (shared weights).
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """Spatial attention sub-module of CBAM.

    Pools the input across the channel dimension with both mean and max,
    stacks the two single-channel maps, convolves them down to one channel
    and applies a sigmoid. The result is a per-position weight map that the
    caller multiplies with the input feature map.

    Args:
        kernel_size: Size of the convolution applied to the pooled maps.
            Must be a positive odd integer so that
            ``padding=(kernel_size - 1) // 2`` preserves H and W exactly.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.

    Shape:
        Input ``(N, C, H, W)`` -> output ``(N, 1, H, W)`` with values in (0, 1).
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # An even kernel with this padding shrinks the map by one pixel, which
        # would later fail to broadcast against the input; reject it up front.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got {}".format(kernel_size)
            )
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a spatial attention map of shape ``(N, 1, H, W)`` for ``x``."""
        # Channel-wise mean and max give two (N, 1, H, W) descriptors.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        pooled = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(pooled))
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Refines a feature map by applying channel attention and then spatial
    attention in sequence, as described in Woo et al., "CBAM: Convolutional
    Block Attention Module" (ECCV 2018)::

        F'  = Mc(F)  * F
        F'' = Ms(F') * F'

    Args:
        in_channels: Number of channels of the input feature map.
        ratio: Reduction ratio of the channel-attention MLP (default 16).
        kernel_size: Kernel size of the spatial-attention convolution
            (default 7).
    """

    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Return ``x`` refined by channel attention followed by spatial attention.

        The output has the same shape as ``x``.
        """
        # Sequential (channel first, then spatial) refinement: the spatial map
        # is computed from the channel-refined features, not from the raw
        # input, and the two results are not summed.
        x = self.channel_attention(x) * x
        x = self.spatial_attention(x) * x
        return x
# Usage example
class YourModel(nn.Module):
    """Placeholder model showing where a CBAM block plugs into a network.

    Only the CBAM block is defined here; the rest of the architecture is
    left for the user to fill in.
    """

    def __init__(self):
        super().__init__()
        # CBAM block for a 64-channel feature map; 64 must match the channel
        # count of the tensor passed to forward().
        # TODO: define the remaining layers of the model here.
        self.cbam = CBAM(64)

    def forward(self, x):
        """Apply the CBAM block to ``x`` and return the refined tensor."""
        # TODO: add the rest of the forward pass around the CBAM block.
        return self.cbam(x)
# Build an instance of the example model.
model = YourModel()
上述代码是一个简化版本,使用 PyTorch 实现了 CBAM 模块,并在示例模型中演示了它的用法。实际的 CBAM 模块可能需要根据具体需求进行调整;使用时请确保传入的 in_channels 与特征图的通道数一致,并根据项目的实际需要进行集成。
如果你需要更完整的 CBAM 模块实现(超参数可配置,并集成到一个简单的 CNN 中),可以参考下面的代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChannelAttention(nn.Module):
    """Channel attention sub-module of CBAM.

    Squeezes the spatial dimensions with both global average pooling and
    global max pooling, feeds each pooled descriptor through one shared
    bottleneck MLP (built from 1x1 convolutions), sums the two results and
    applies a sigmoid. The result is a per-channel weight map that the caller
    multiplies with the input feature map.

    Args:
        in_channels: Number of channels C of the input feature map.
        ratio: Reduction ratio of the bottleneck MLP. The hidden width is
            ``in_channels // ratio``, clamped to at least 1 so that inputs
            with fewer than ``ratio`` channels still get a usable MLP.

    Shape:
        Input ``(N, C, H, W)`` -> output ``(N, C, 1, 1)`` with values in (0, 1).
    """

    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Without the clamp, in_channels < ratio gives a hidden width of 0,
        # i.e. a degenerate zero-width bottleneck that cannot learn anything.
        hidden_channels = max(1, in_channels // ratio)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return channel attention weights of shape ``(N, C, 1, 1)`` for ``x``."""
        # Both pooled descriptors go through the same MLP (shared weights).
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """Spatial attention sub-module of CBAM.

    Pools the input across the channel dimension with both mean and max,
    stacks the two single-channel maps, convolves them down to one channel
    and applies a sigmoid. The result is a per-position weight map that the
    caller multiplies with the input feature map.

    Args:
        kernel_size: Size of the convolution applied to the pooled maps.
            Must be a positive odd integer so that
            ``padding=(kernel_size - 1) // 2`` preserves H and W exactly.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.

    Shape:
        Input ``(N, C, H, W)`` -> output ``(N, 1, H, W)`` with values in (0, 1).
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # An even kernel with this padding shrinks the map by one pixel, which
        # would later fail to broadcast against the input; reject it up front.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got {}".format(kernel_size)
            )
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a spatial attention map of shape ``(N, 1, H, W)`` for ``x``."""
        # Channel-wise mean and max give two (N, 1, H, W) descriptors.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        pooled = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(pooled))
class CBAMModule(nn.Module):
    """Convolutional Block Attention Module.

    Refines a feature map by applying channel attention and then spatial
    attention in sequence, as described in Woo et al., "CBAM: Convolutional
    Block Attention Module" (ECCV 2018)::

        F'  = Mc(F)  * F
        F'' = Ms(F') * F'

    Args:
        in_channels: Number of channels of the input feature map.
        ratio: Reduction ratio of the channel-attention MLP.
        kernel_size: Kernel size of the spatial-attention convolution.
    """

    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super(CBAMModule, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Return ``x`` refined by channel attention followed by spatial attention.

        The output has the same shape as ``x``.
        """
        # Sequential (channel first, then spatial) refinement: the spatial map
        # is computed from the channel-refined features, not from the raw
        # input, and the two results are not summed.
        x = self.channel_attention(x) * x
        x = self.spatial_attention(x) * x
        return x
class SimpleCNNWithCBAM(nn.Module):
    """Small two-stage CNN with a CBAM block after each convolution.

    Architecture::

        conv3x3(64)  -> ReLU -> CBAM -> maxpool(2)
        conv3x3(128) -> ReLU -> CBAM -> maxpool(2)
        flatten -> fc(512) -> ReLU -> fc(num_classes)

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input images (default 3).
        input_size: Height/width of the square input images (default 32,
            which reproduces the original 128 * 8 * 8 fc1 input). The 3x3
            convolutions use padding 1 and keep the spatial size, and each
            2x2 max-pool halves it (rounding down), so fc1 receives
            ``128 * (input_size // 4) ** 2`` features.

    Raises:
        ValueError: If ``input_size`` is smaller than 4 (no features would
            survive the two pooling stages).
    """

    def __init__(self, num_classes=10, in_channels=3, input_size=32):
        super(SimpleCNNWithCBAM, self).__init__()
        if input_size < 4:
            raise ValueError(
                "input_size must be at least 4, got {}".format(input_size)
            )
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.cbam1 = CBAMModule(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.cbam2 = CBAMModule(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        pooled_size = input_size // 4
        self.fc1 = nn.Linear(128 * pooled_size * pooled_size, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)`` for a batch of images ``x``."""
        x = F.relu(self.conv1(x))
        x = self.cbam1(x)
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.cbam2(x)
        x = self.pool2(x)
        # Flatten per sample, keeping dim 0 as the batch. Reshaping with
        # view(-1, <fixed size>) would silently fold extra elements into the
        # batch dimension for unexpected input sizes; here a mismatch fails
        # loudly at fc1 instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Build an instance of the example CNN (10 output classes by default).
model = SimpleCNNWithCBAM(num_classes=10)
这个示例中,SimpleCNNWithCBAM 是一个简单的卷积神经网络,其中使用了两个 CBAM 模块,分别作用于第一个和第二个卷积层(经 ReLU 激活后)的输出。注意,全连接层 fc1 的输入维度 128 * 8 * 8 对应 32×32 的输入图像(经过两次 2×2 最大池化后为 8×8);若输入尺寸不同,需要相应调整。你可以根据你的需求调整模型结构和 CBAM 模块的超参数。