CBAM 代码

CBAM（Convolutional Block Attention Module，卷积块注意力模块）是一种用于增强卷积神经网络特征表达能力的注意力机制，它分别在通道和空间两个维度上对特征图进行加权。

python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelAttention(nn.Module):
    """Channel-attention sub-module of CBAM.

    Squeezes the spatial dimensions of a 4-D feature map ``(N, C, H, W)`` with
    both global average pooling and global max pooling, feeds each pooled
    descriptor through a shared two-layer bottleneck (1x1 convolutions), sums
    the two results and applies a sigmoid.  The returned weights have shape
    ``(N, C, 1, 1)`` and broadcast against the input.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
        ratio: Reduction ratio of the bottleneck.  The hidden width is
            ``in_channels // ratio``, clamped to at least 1 so that inputs
            with fewer channels than ``ratio`` still work.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        # Without the clamp, in_channels < ratio would give a 0-channel conv.
        hidden = max(in_channels // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP: the same weights process the avg- and max-pooled vectors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Spatial-attention sub-module of CBAM.

    Pools the input along the channel axis (mean and max), concatenates the
    two single-channel maps and applies a ``kernel_size`` x ``kernel_size``
    convolution followed by a sigmoid.  The returned weights have shape
    ``(N, 1, H, W)`` and broadcast against the input.

    Args:
        kernel_size: Convolution kernel size.  Must be a positive odd integer
            so that the padding ``(kernel_size - 1) // 2`` preserves ``H, W``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # An even kernel would shrink H and W by one and break the later
        # element-wise multiplication with the input; fail early instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.conv = nn.Conv2d(
            2, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(x))


class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Refines a feature map by applying channel attention and then spatial
    attention *sequentially*: the spatial map is computed from the
    channel-refined features.  The output has the same shape as the input.

    Args:
        in_channels: Number of channels of the input feature map.
        ratio: Reduction ratio of the channel-attention bottleneck.
        kernel_size: Kernel size of the spatial-attention convolution (odd).
    """

    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # Channel attention first, then spatial attention on the refined map.
        x = self.channel_attention(x) * x
        x = self.spatial_attention(x) * x
        return x


# Usage example
class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Use CBAM inside your model; 64 is the number of input channels here.
        self.cbam = CBAM(in_channels=64)
        # Other model definitions...

    def forward(self, x):
        # Other forward-pass logic...
        x = self.cbam(x)
        return x


# Create a model instance
model = YourModel()

上述代码使用 PyTorch 实现了 CBAM 模块，并在示例模型中演示了其用法。这是一个简化版本，请根据项目的实际需求调整超参数（例如通道注意力的压缩比 ratio 和空间注意力的卷积核大小 kernel_size），再集成到你的模型中。

如果你需要更完整的示例，可以参考下面的实现，它将 CBAM 模块集成到了一个简单的卷积神经网络中：

python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelAttention(nn.Module):
    """Channel-attention sub-module of CBAM.

    Squeezes the spatial dimensions of a 4-D feature map ``(N, C, H, W)`` with
    both global average pooling and global max pooling, feeds each pooled
    descriptor through a shared two-layer bottleneck (1x1 convolutions), sums
    the two results and applies a sigmoid.  The returned weights have shape
    ``(N, C, 1, 1)`` and broadcast against the input.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
        ratio: Reduction ratio of the bottleneck.  The hidden width is
            ``in_channels // ratio``, clamped to at least 1 so that inputs
            with fewer channels than ``ratio`` still work.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        # Without the clamp, in_channels < ratio would give a 0-channel conv.
        hidden = max(in_channels // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP: the same weights process the avg- and max-pooled vectors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Spatial-attention sub-module of CBAM.

    Pools the input along the channel axis (mean and max), concatenates the
    two single-channel maps and applies a ``kernel_size`` x ``kernel_size``
    convolution followed by a sigmoid.  The returned weights have shape
    ``(N, 1, H, W)`` and broadcast against the input.

    Args:
        kernel_size: Convolution kernel size.  Must be a positive odd integer
            so that the padding ``(kernel_size - 1) // 2`` preserves ``H, W``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # An even kernel would shrink H and W by one and break the later
        # element-wise multiplication with the input; fail early instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.conv = nn.Conv2d(
            2, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(x))


class CBAMModule(nn.Module):
    """Convolutional Block Attention Module.

    Refines a feature map by applying channel attention and then spatial
    attention *sequentially*: the spatial map is computed from the
    channel-refined features.  The output has the same shape as the input.

    Args:
        in_channels: Number of channels of the input feature map.
        ratio: Reduction ratio of the channel-attention bottleneck.
        kernel_size: Kernel size of the spatial-attention convolution (odd).
    """

    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # Channel attention first, then spatial attention on the refined map.
        x = self.channel_attention(x) * x
        x = self.spatial_attention(x) * x
        return x


class SimpleCNNWithCBAM(nn.Module):
    """Small two-stage CNN with a CBAM block after each convolution.

    Expects RGB input of shape ``(N, 3, 32, 32)``: the two 2x2 max-pools
    reduce 32x32 to 8x8, which is what ``fc1`` (``128 * 8 * 8`` inputs) is
    sized for.  Other spatial sizes raise a shape error in ``fc1``.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.cbam1 = CBAMModule(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.cbam2 = CBAMModule(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.cbam1(x)
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.cbam2(x)
        x = self.pool2(x)
        # Keep the batch dimension: view(-1, 128*8*8) would silently fold
        # extra spatial positions into the batch for non-32x32 inputs.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Create a model instance
model = SimpleCNNWithCBAM()

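在这个示例中，SimpleCNNWithCBAM 是一个简单的卷积神经网络，其中包含两个 CBAM 模块，分别作用于第一个和第二个卷积层（经 ReLU 激活后）的输出。注意：全连接层 fc1 的输入维度 128 * 8 * 8 假设输入图像大小为 32×32（例如 CIFAR-10），因为两次 2×2 最大池化会把 32×32 缩小到 8×8；若输入尺寸不同，需要相应修改 fc1。你可以根据需求调整模型结构和 CBAM 模块的超参数。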