UNet进行病理图像分割

数据集链接:https://pan.baidu.com/s/1IBe_P0AyHgZC39NqzOxZhA?pwd=nztc
提取码:nztc

  • UNet模型
import torch
import torch.nn as nn

class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.up(x)
        return x
class UNet(nn.Module):
    def __init__(self, img_ch=3, output_ch=1):
        super(UNet, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)
        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d1 = self.Conv_1x1(d2)
        output = torch.sigmoid(d1)  # 在最后加上Sigmoid激活函数
        return output
  • 数据加载
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, output_size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_list = os.listdir(image_dir)
        self.output_size = output_size
        # 定义图像和掩码的变换
        self.image_transform = transforms.Compose([
            transforms.Resize(self.output_size),
            transforms.ToTensor()
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(self.output_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        image_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, image_name)
        image = Image.open(image_path).convert("RGB")  # 确保是RGB
        mask = Image.open(mask_path).convert("L")  # 确保是灰度图像
        image = self.image_transform(image)
        mask = self.mask_transform(mask)
        return image, mask
  • 训练和测试。训练函数中保存的最好模型后缀最大(因为loss小才保存当前这个epoch的模型,我训练的最好模型是第171轮产生的),测试代码包含计算模型性能指标的代码和保存结果图片的代码。
import os
import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from UNet import UNet
from DataLoader2 import SegmentationDataset

# IoU计算
def compute_iou(pred_mask, true_mask):
    smooth = 1e-6  # 避免分母为0
    pred_mask = (pred_mask > 0.5).float()
    true_mask = (true_mask > 0.5).float()

    intersection = (pred_mask * true_mask).sum()
    union = pred_mask.sum() + true_mask.sum() - intersection

    return (intersection + smooth) / (union + smooth)

# Dice系数计算
def compute_dice(pred_mask, true_mask):
    smooth = 1e-6  # 避免分母为0
    pred_mask = (pred_mask > 0.5).float()
    true_mask = (true_mask > 0.5).float()

    intersection = (pred_mask * true_mask).sum()

    return (2. * intersection + smooth) / (pred_mask.sum() + true_mask.sum() + smooth)

# 精度、召回率和F1分数计算
def compute_precision_recall_f1(pred_mask, true_mask):
    pred_mask = (pred_mask > 0.5).numpy().astype(int)
    true_mask = (true_mask > 0.5).numpy().astype(int)

    # 将mask平展为一维数组
    pred_mask_flat = pred_mask.flatten()
    true_mask_flat = true_mask.flatten()

    conf_matrix = confusion_matrix(true_mask_flat, pred_mask_flat)
    tn, fp, fn, tp = conf_matrix.ravel()

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1_score = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1_score


# 训练函数
def train():
    model = UNet()
    dataset = SegmentationDataset('./dataset_exp2/train/image', './dataset_exp2/train/label')
    dataloader = DataLoader(batch_size=16, shuffle=True, dataset=dataset)
    # 训练参数
    num_epochs = 200
    learning_rate = 1e-4
    # 损失函数和优化器
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # 设备
    device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.train()
    best_loss = float('inf')
    for epoch in range(num_epochs):
        epoch_loss = 0
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), f'./save_model_UNet/res_{epoch + 1}.pth')
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / len(dataloader)}')


def test():
    model = UNet()
    # 确保模型在CPU上
    model.load_state_dict(torch.load('./save_model_UNet/res_171.pth'))
    save_dir = './test_results_UNet'
    model.eval()
    dataset = SegmentationDataset('./dataset_exp2/test/image', './dataset_exp2/test/label')
    dataloader = DataLoader(batch_size=1, shuffle=False, dataset=dataset)
    iou_list = []
    dice_list = []
    precision_list = []
    recall_list = []
    f1_list = []
    plt.ion()
    with torch.no_grad():
        for idx, (images, labels) in tqdm(enumerate(dataloader)):
            pre = model(images)
            img_pre = torch.squeeze(pre)
            img_true = torch.squeeze(labels)
            iou = compute_iou(img_pre, img_true)
            dice = compute_dice(img_pre, img_true)
            precision, recall, f1_score = compute_precision_recall_f1(img_pre, img_true)
            img_pre = img_pre.numpy()
            img_true = img_true.numpy()
            img_x = torch.squeeze(images).numpy().transpose(1, 2, 0)
            img_x = (img_x * 255).astype(np.uint8)  # 恢复到0-255的范围
            # 保存结果
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            plt.title('Input Image')
            plt.imshow(img_x)
            plt.axis('off')

            plt.subplot(1, 3, 2)
            plt.title('True Mask')
            plt.imshow(img_true, cmap='gray')
            plt.axis('off')

            plt.subplot(1, 3, 3)
            plt.title('UNet Predicted Mask')
            plt.imshow(img_pre, cmap='gray')
            plt.axis('off')

            plt.savefig(os.path.join(save_dir, f'result_{idx + 1}.png'))
            plt.close()  # 关闭当前figure,避免内存占用过多

            iou_list.append(iou.item())
            dice_list.append(dice.item())
            precision_list.append(precision)
            recall_list.append(recall)
            f1_list.append(f1_score)

        plt.ioff()  # 关闭交互模式
        print(f'Results saved in {save_dir}')
        print(f'Average IoU: {np.mean(iou_list)}')
        print(f'Average Dice Coefficient: {np.mean(dice_list)}')
        print(f'Average Precision: {np.mean(precision_list)}')
        print(f'Average Recall: {np.mean(recall_list)}')
        print(f'Average F1 Score: {np.mean(f1_list)}')

if __name__ == '__main__':
    print('++++++++++++++++train++++++++++++++++')
    train()
    print('++++++++++++++++test++++++++++++++++')
    test()

测试效果:
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/763626.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JVM原理(十):JVM虚拟机调优分析与实战

1. 大内存硬件上的程序部署策略 这是笔者很久之前处理过的一个案例&#xff0c;但今天仍然具有代表性。一个15万PV/日左右的在线文档类型网站最近更换了硬件系统&#xff0c;服务器的硬件为四路志强处理器、16GB物理内存&#xff0c;操作系统为64位CentOS5.4&#xff0c;Resin…

Android Studio 解决AAPT: error: file failed to compile

1.找到项目下的build.gradle 2.在android语块中添加下面代码 aaptOptions.cruncherEnabled false aaptOptions.useNewCruncher false 12

Linux中的库

什么是库&#xff1f; 库是一组预先编译好的方法/函数的集合&#xff0c;其他程序想要使用源文件中的函数时&#xff0c;只需在编译可执行程序时&#xff0c;链接上该源文件生成的库文件即可。 库分为两类&#xff1a;静态库和动态库 在Linux系统中&#xff0c;以.a为后缀的…

力扣热100 哈希

哈希 1. 两数之和49.字母异位词分组128.最长连续序列 1. 两数之和 题目&#xff1a;给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数&#xff0c;并返回它们的数组下标。你可以假设每种输入只会对应一个答案。…

【NOI-题解】1326. 需要安排几位师傅加工零件1228. 排队打水问题1229. 拦截导弹的系统数量求解

文章目录 一、前言二、问题问题&#xff1a;1326. 需要安排几位师傅加工零件问题&#xff1a;1228. 排队打水问题问题&#xff1a;1229. 拦截导弹的系统数量求解 三、感谢 一、前言 本章节主要对贪心问题进行讲解&#xff0c;包括《1326. 需要安排几位师傅加工零件》《1228. 排…

每天五分钟深度学习:解决for循环效率慢的关键在于向量化

本文重点 上一节课程中,我们学习了多样本的线性回归模型,但是我们的伪代码实现中使用了大量的for循环,这样代码的问题是效率很低。为了克服这一瓶颈,向量化技术应运而生,成为提升程序执行效率、加速数据处理速度的重要手段。 向量化技术概述 向量化(Vectorization)是…

目标检测算法讲解:从传统方法到深度学习,全面解析检测技术的演进与应用!

在计算机视觉领域&#xff0c;目标检测是一个基本且关键的任务&#xff0c;它不仅涉及图像中对象的识别&#xff0c;还包括确定这些对象的具体位置。这一任务通常通过算法来实现&#xff0c;这些算法能够识别出图像中的一个或多个目标&#xff0c;并给出每个目标的类别和位置。…

Kafka-服务端-网络层-源码流程

整体架构如下所示&#xff1a; responseQueue不在RequestChannel中&#xff0c;在Processor中&#xff0c;每个Processor内部有一个responseQueue 客户端发送的请求被Acceptor转发给Processor处理处理器将请求放到RequestChannel的requestQueue中KafkaRequestHandler取出reque…

Python:Python简介

一、Python简介 1.Python的诞生 诞生&#xff1a;1989年圣诞节期间&#xff0c;Guido van Rossum为了打发圣诞节假期的无聊&#xff0c;便开始了Python语言的编写。 命名&#xff1a;Python第一个发行版本是在1991年&#xff0c;起名为Python是源自于Guido喜欢的一档电视节目…

英伟达经济学:云服务商在GPU上每花1美元 就能赚7美元

NVIDIA超大规模和 HPC 业务副总裁兼总经理 Ian Buck 近日在美国银行证券 2024 年全球技术大会上表示&#xff0c;客户正在投资数十亿美元购买新的NVIDIA硬件&#xff0c;以跟上更新的 AI 大模型的需求&#xff0c;从而提高收入和生产力。 Buck表示&#xff0c;竞相建设大型数据…

在 PostgreSQL 中强制执行连接顺序#postgresql认证

让我们首先创建一些表&#xff1a; PgSQL plan# SELECT CREATE TABLE x || id || (id int) FROM generate_series(1, 5) AS id;?column? --------------------------CREATE TABLE x1 (id int)CREATE TABLE x2 (id int)CREATE TABLE x3 (id int)CREATE TABLE…

Centos7网络配置(设置固定ip)

文章目录 1进入虚拟机设置选中【网络适配器】选择【NAT模式】2 进入windows【控制面板\网络和 Internet\网络和共享中心\更改适配器设置】设置网络状态。3 设置VM的【虚拟网络编辑器】4 设置系统网卡5 设置虚拟机固定IP 刚安装完系统&#xff0c;有的人尤其没有勾选自动网络配置…

解锁机器学习算法面试挑战课程

在这个课程中&#xff0c;我们将从基础知识出发&#xff0c;系统学习机器学习与算法的核心概念和实践技巧。通过大量案例分析和LeetCode算法题解&#xff0c;帮助您深入理解各种面试问题&#xff0c;并掌握解题技巧和面试技巧。无论是百面挑战还是LeetCode算法题&#xff0c;都…

VUE3解决跨域问题

本文基于vue3 vite element-plus pnpm 报错&#xff1a;**** has been blocked by CORS policy: No Access-Control-Allow-Origin header is present on the requested resource. 原因&#xff1a;前端不能直接访问其他IP&#xff0c;需要用vite.config.ts &#xff0…

仿美团饿了么程序,外卖人9.0商业版外卖订餐源码(PC+微信)

仿美团饿了么程序,外卖人9.0外卖订餐源码,PC微信WAP短信宝,多城市多色版 非常不错的独立版外卖跑腿网站源码&#xff0c;喜欢的可以下载调试看看吧&#xff01;&#xff01; 仿美团饿了么程序,外卖人9.0外卖订餐源码

鸿蒙开发Ability Kit(程序访问控制):【向用户申请单次授权】

申请使用受限权限 受限开放的权限通常是不允许三方应用申请的。当应用在申请权限来访问必要的资源时&#xff0c;发现部分权限的等级比应用APL等级高&#xff0c;开发者可以选择通过ACL方式来解决等级不匹配的问题&#xff0c;从而使用受限权限。 举例说明&#xff0c;如果应…

【面试干货】Static关键字的用法详解

【面试干货】Static关键字的用法详解 1、Static修饰内部类2、Static修饰方法3、Static修饰变量4、Static修饰代码块5、总结 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 在Java编程语言中&#xff0c;static是一个关键字&#xff0c;它可…

【多模态LLM】以ViT进行视觉表征的多模态模型1(BLIP-2、InstructBLIP)

note CLIP和BLIP的区别&#xff1a; CLIP&#xff1a;通过对比学习联合训练&#xff0c;预测图像和文本之间的匹配关系。即使用双塔结构&#xff0c;分别对图像和文本编码&#xff0c;然后通过计算cos进行图文匹配。BLIP&#xff1a;包括两个单模态编码器&#xff08;图像编码…

【TB作品】温湿度监控系统设计,ATMEGA16单片机,Proteus仿真

题2:温湿度监控系统设计 功能要求: 1)开机显示时间(小时、分)、时分可修改; 2)用两个滑动变阻器分别模拟温度传感器(测量范 围0-100度)与湿度传感器(0-100%),通过按键 可以在数码管切换显示当前温度值、湿度值; 3)当温度低于20度时,红灯长亮; 4)当湿度高于70%时,黄灯长亮; 5)当…

win11自动删除文件的问题,安全中心提示

win11自动删除文件的问题&#xff0c;解决方法&#xff1a; 1.点击任务栏上的开始图标&#xff0c;在显示的应用中&#xff0c;点击打开设置。 或者点击电脑右下角的开始也可以 2.点击设置。也可以按Wini打开设置窗口。 3.左侧点击隐私和安全性&#xff0c;右侧点击Windows安全…