Files
cDNA-image-processing/src/cDNA_gridding_simple.py
T

369 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
cDNA微阵列图像处理 —— 简化版
======================================
运行方式:python src/cDNA_gridding_simple.py
一、算法流程总览
灰度图 ──→ Otsu求像素最佳阈值 T ──→ 百分比 = T/255(自适应)
├─→ 投影/减阈值/过零点配对 ──→ 网格线
├─→ 逐格 Otsu 分割 ──→ keep_largest_object(每格留最大块)
└─→ remove_small_objects(对连通域面积做 Otsu,小面积一类判为噪声)──→ 统计斑点数
二、各步骤详解
1. 彩色图 → 灰度图
2. Otsu 自动阈值
遍历灰度 0~255,每个候选 T 将像素分为前景(>T)和背景(≤T),
计算类内方差 w_bg×σ²_bg + w_fg×σ²_fg,选使方差最小的 T。
3. 投影
横轴:np.sum(每列) → 曲线,高点=斑点列,低点=空隙列
纵轴:np.sum(每行) → 曲线,高点=斑点行,低点=空隙行
4. 阈值 X = (max-min) × (T/255)
5. 曲线减 X → 大于 0 = 斑点区域,小于 0 = 空隙
过零点 = 斑点和空隙的分界线
6. 过零点配对
过零点交替:正→负(离开斑点)、负→正(进入下一斑点)
配对「离开斑点 + 进入下一斑点」,中点 = 空隙中央 = 划线位置
7. 逐格分割 + 后处理
对每个格子独立做 Otsu → keep_largest_object(留最大块)
→ 全局 remove_small_objects(自动去噪)→ 统计斑点数
8. 输出 6 张独立图片:网格叠加、列投影、行投影、灰度直方图、逐格分割、后处理结果
"""
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage import color
from scipy import ndimage
# Matplotlib text setup: list CJK-capable sans-serif fonts and turn off the
# Unicode minus glyph (it is commonly missing from those fonts).
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
# All paths are derived from this script's location; no absolute paths are
# hard-coded.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
BASE_DIR = os.path.split(SCRIPT_DIR)[0]
DATA_DIR = os.path.join(BASE_DIR, *('cDNA图像处理实例', '数据', 'cDNA'))
OUTPUT_DIR = os.path.join(BASE_DIR, 'results_simple')
# ================================================================
# 函数1:Otsu 像素级阈值
# ================================================================
def otsu_threshold_pixels(gray: np.ndarray) -> int:
    """
    Find the Otsu threshold of an image by minimising within-class variance.

    Every integer candidate T in 1..254 splits the pixels into background
    (value <= T) and foreground (value > T). The candidate with the smallest
    weighted within-class variance ``w_bg * var_bg + w_fg * var_fg`` wins;
    when several candidates tie, the lowest T is returned.

    Args:
        gray: Array of pixel intensities (any shape). Values are compared
            against the integer candidates 1..254.

    Returns:
        The selected threshold as a Python int, or 0 when no candidate splits
        the pixels into two non-empty groups (e.g. empty or constant input).
    """
    values = np.sort(np.asarray(gray, dtype=np.float64).ravel())
    total = values.size
    if total == 0:
        # Nothing to threshold; avoid dividing by a zero pixel count.
        return 0

    candidates = np.arange(1, 255)
    # Number of pixels <= T for every candidate, found in one pass over the
    # sorted values instead of re-masking the whole image per candidate.
    n_bg = np.searchsorted(values, candidates, side='right')
    n_fg = total - n_bg
    valid = (n_bg > 0) & (n_fg > 0)
    if not valid.any():
        # Every candidate leaves one group empty.
        return 0

    # Prefix sums give each group's sum and sum of squares in O(1).
    prefix_sum = np.concatenate(([0.0], np.cumsum(values)))
    prefix_sq = np.concatenate(([0.0], np.cumsum(values * values)))

    count_bg = n_bg[valid].astype(np.float64)
    count_fg = n_fg[valid].astype(np.float64)
    sum_bg = prefix_sum[n_bg[valid]]
    sq_bg = prefix_sq[n_bg[valid]]
    sum_fg = prefix_sum[-1] - sum_bg
    sq_fg = prefix_sq[-1] - sq_bg

    # Sum of squared deviations of both groups. This equals
    # total * (w_bg * var_bg + w_fg * var_fg), so it has the same minimiser.
    cost = np.full(candidates.shape, np.inf)
    cost[valid] = ((sq_bg - sum_bg * sum_bg / count_bg)
                   + (sq_fg - sum_fg * sum_fg / count_fg))

    # argmin returns the first minimum, i.e. the lowest tying threshold.
    return int(candidates[np.argmin(cost)])
# ================================================================
# 函数2:网格划线
# ================================================================
def draw_grid_lines(gray: np.ndarray):
    """
    Detect grid separator lines from the row and column intensity projections.

    Pipeline: the pixel-level Otsu threshold T gives an adaptive percentage
    ``pct = T / 255``; the image is summed along each axis; the value
    ``(max - min) * pct`` is subtracted from each projection; sign changes of
    the shifted projection are paired into gaps, and the midpoint of every gap
    becomes a grid line.

    Args:
        gray: 2-D grayscale image.

    Returns:
        A 10-tuple ``(x_lines, y_lines, T, pct, col_profile, row_profile,
        col_shifted, row_shifted, col_T_val, row_T_val)`` where ``x_lines`` /
        ``y_lines`` are integer arrays of vertical / horizontal line
        positions (possibly empty), ``col_profile`` / ``row_profile`` are the
        raw projections, ``*_shifted`` are the projections minus their
        threshold, and ``col_T_val`` / ``row_T_val`` are those thresholds.
    """
    T = otsu_threshold_pixels(gray)
    pct = T / 255.0

    # Projections: sum over rows gives one value per column, and vice versa.
    col_profile = np.sum(gray, axis=0).astype(float)
    row_profile = np.sum(gray, axis=1).astype(float)

    # Threshold for each projection, scaled by the adaptive percentage.
    col_T_val = (np.max(col_profile) - np.min(col_profile)) * pct
    row_T_val = (np.max(row_profile) - np.min(row_profile)) * pct

    # After the shift, positive samples are treated as spots and the rest
    # as gaps.
    col_shifted = col_profile - col_T_val
    row_shifted = row_profile - row_T_val

    def find_gap_lines(prof_shifted: np.ndarray) -> np.ndarray:
        """
        Return the midpoints of the gaps between consecutive positive runs.

        A gap is bounded by a positive-to-non-positive crossing (leaving a
        spot) followed by the next crossing back to positive (entering the
        next spot). Leading and trailing gaps, which have only one such
        crossing, produce no line.
        """
        is_positive = prof_shifted > 0
        # Index i is a crossing when the sign at i differs from i - 1.
        crossings = np.flatnonzero(is_positive[1:] != is_positive[:-1]) + 1
        if crossings.size < 2:
            return np.array([], dtype=int)

        # If the profile starts in a gap, the first crossing enters a spot
        # and cannot open a gap, so skip it.
        start = 0 if is_positive[0] else 1
        leaving = crossings[start:-1:2]
        entering = crossings[start + 1::2]
        return ((leaving + entering) // 2).astype(int)

    x_lines = find_gap_lines(col_shifted)
    y_lines = find_gap_lines(row_shifted)
    return (x_lines, y_lines, T, pct, col_profile, row_profile,
            col_shifted, row_shifted, col_T_val, row_T_val)
# ================================================================
# 函数3:后处理(完全自动,无需人工设定阈值)
# ================================================================
def keep_largest_object(binary: np.ndarray) -> np.ndarray:
    """
    Keep only the largest connected component of a binary mask.

    Components are found with ``ndimage.label`` (default connectivity). When
    several components share the maximum area, the one with the lowest label
    is kept. No threshold parameter is involved.

    Args:
        binary: Mask whose non-zero entries are foreground.

    Returns:
        A ``uint8`` array of the same shape with 1 on the largest component
        and 0 elsewhere; all zeros when the mask has no foreground.
    """
    labeled, num = ndimage.label(binary)
    if num == 0:
        # Same dtype as the non-empty path so callers always get uint8.
        return np.zeros(binary.shape, dtype=np.uint8)

    # One pass over the label image gives every component's pixel count;
    # index 0 is the background and is dropped.
    areas = np.bincount(labeled.ravel())[1:]
    max_idx = int(np.argmax(areas)) + 1
    return (labeled == max_idx).astype(np.uint8)
# ================================================================
# 函数4:自动去除小连通域(噪声)
# ================================================================
def remove_small_objects(binary: np.ndarray) -> np.ndarray:
    """
    Remove small connected components (noise) using an automatic area cut.

    Otsu's method is applied to the list of component areas: every distinct
    area T is tried as a split into a "small" group (area <= T) and a "large"
    group (area > T), and the T with the lowest weighted within-group
    variance is chosen. Every component in the small group, including those
    whose area equals T, is erased. This assumes the areas are roughly
    bimodal (tiny noise blobs versus real spots).

    Args:
        binary: Mask whose non-zero entries are foreground.

    Returns:
        A copy of ``binary`` with the small-group components set to 0. With
        fewer than two components, or when all components have the same area,
        the copy is unchanged.
    """
    labeled, num = ndimage.label(binary)
    if num < 2:
        # Zero or one component: there is nothing to split.
        return binary.copy()

    # Pixel count per component from one pass; index 0 (background) dropped.
    areas = np.bincount(labeled.ravel(), minlength=num + 1)[1:]

    best_T, best_cost, n_total = 0, float('inf'), len(areas)
    for T in np.unique(areas):
        small = areas[areas <= T]
        large = areas[areas > T]
        if len(small) == 0 or len(large) == 0:
            continue
        cost = (len(small) / n_total) * np.var(small) \
            + (len(large) / n_total) * np.var(large)
        if cost < best_cost:
            best_cost = cost
            best_T = T

    # best_T stays 0 when no valid split exists, so nothing is removed then.
    # Use <= so the removal matches the Otsu grouping: a component whose area
    # equals best_T belongs to the small (noise) group.
    is_noise = np.concatenate(([False], areas <= best_T))
    result = binary.copy()
    result[is_noise[labeled]] = 0
    return result
# ================================================================
# 主流程
# ================================================================
def main():
    """
    Run the full pipeline on ``cDNA.png`` and save six diagnostic figures.

    Steps: load the image from DATA_DIR and convert it to 8-bit grayscale,
    detect grid lines, segment every grid cell with its own Otsu threshold
    (keeping the largest blob per cell), remove noise components, count the
    remaining spots, and write the figures to OUTPUT_DIR.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ---- Load the image and convert to grayscale ----
    # convert('RGB') drops any alpha channel and also accepts grayscale or
    # palette images; the context manager closes the file handle.
    with Image.open(os.path.join(DATA_DIR, 'cDNA.png')) as im:
        rgb = np.array(im.convert('RGB'))
    gray = (color.rgb2gray(rgb) * 255).astype(np.uint8)

    # ---- 1. Grid lines ----
    (x_lines, y_lines, T_otsu, pct,
     col_prof, row_prof, col_shifted, row_shifted,
     col_T_val, row_T_val) = draw_grid_lines(gray)
    print(f"检测到 {len(x_lines)} 条纵线, {len(y_lines)} 条横线")
    print(f"Otsu 阈值: T={T_otsu}, 自适应百分比: {pct*100:.1f}%")

    # ---- 2. Per-cell segmentation + post-processing ----
    # Grid lines lie only in the gaps BETWEEN spots, so the outermost spot
    # row/column sits between a detected line and the image border. Add the
    # borders as cell edges so those spots are segmented too.
    H, W = gray.shape
    row_edges = np.concatenate(([0], np.asarray(y_lines, dtype=int), [H]))
    col_edges = np.concatenate(([0], np.asarray(x_lines, dtype=int), [W]))
    bw_full = np.zeros_like(gray)
    for r1, r2 in zip(row_edges[:-1], row_edges[1:]):
        for c1, c2 in zip(col_edges[:-1], col_edges[1:]):
            blk = gray[r1:r2, c1:c2]
            if blk.size == 0:
                continue
            T = otsu_threshold_pixels(blk)
            bw_blk = (blk > T).astype(np.uint8)
            bw_full[r1:r2, c1:c2] = keep_largest_object(bw_blk)
    bw_clean = remove_small_objects(bw_full)

    # ---- 3. Count spots ----
    labeled, num = ndimage.label(bw_clean)
    # Component areas from one pass; index 0 is the background.
    spot_sizes = np.bincount(labeled.ravel())[1:]
    # Components smaller than 10 pixels are not counted as spots.
    valid = [int(s) for s in spot_sizes if s >= 10]
    print(f"检测到 {len(valid)} 个斑点")

    # ---- 4. Figures (each saved to its own file) ----
    # Figure 1: grid lines over the grayscale image.
    fig1, ax1 = plt.subplots(figsize=(8, 8))
    ax1.imshow(gray, cmap='gray')
    for x in x_lines:
        ax1.axvline(x=x, color='lime', linewidth=0.5)
    for y in y_lines:
        ax1.axhline(y=y, color='lime', linewidth=0.5)
    ax1.set_title(f'网格划分 ({len(x_lines)}x{len(y_lines)})', fontsize=13)
    ax1.axis('off')
    fig1.savefig(os.path.join(OUTPUT_DIR, '01_grid_overlay.png'),
                 dpi=150, bbox_inches='tight')
    plt.close(fig1)

    # Figure 2: column projection with threshold, sign changes (first 50)
    # and the resulting grid lines.
    fig2, ax2 = plt.subplots(figsize=(10, 4))
    xs = np.arange(len(col_prof))
    ax2.plot(xs, col_prof, 'b-', linewidth=0.6, label='col profile')
    ax2.axhline(y=col_T_val, color='orange', linestyle='--', linewidth=1,
                label=f'threshold X={col_T_val:.0f}')
    ax2.plot(xs, col_shifted, 'g-', linewidth=0.6, alpha=0.5, label='after -X')
    ax2.fill_between(xs, 0, col_shifted, where=(col_shifted > 0),
                     color='green', alpha=0.1)
    ax2.fill_between(xs, 0, col_shifted, where=(col_shifted < 0),
                     color='red', alpha=0.1)
    zero_idx = np.where(np.diff(col_shifted > 0) != 0)[0]
    for zi in zero_idx[:50]:
        ax2.axvline(x=zi, color='purple', linewidth=0.3, alpha=0.5)
    for xl in x_lines:
        ax2.axvline(x=xl, color='red', linewidth=0.8, alpha=0.7)
    ax2.set_title('col projection', fontsize=12)
    ax2.set_xlabel('col')
    ax2.legend(fontsize=8)
    fig2.savefig(os.path.join(OUTPUT_DIR, '02_col_projection.png'),
                 dpi=120, bbox_inches='tight')
    plt.close(fig2)

    # Figure 3: row projection, drawn with the row index on the y axis.
    fig3, ax3 = plt.subplots(figsize=(10, 4))
    ys = np.arange(len(row_prof))
    ax3.plot(row_prof, ys, 'b-', linewidth=0.6, label='row profile')
    ax3.axvline(x=row_T_val, color='orange', linestyle='--', linewidth=1,
                label=f'threshold X={row_T_val:.0f}')
    ax3.plot(row_shifted, ys, 'g-', linewidth=0.6, alpha=0.5, label='after -X')
    ax3.fill_betweenx(ys, 0, row_shifted, where=(row_shifted > 0),
                      color='green', alpha=0.1)
    ax3.fill_betweenx(ys, 0, row_shifted, where=(row_shifted < 0),
                      color='red', alpha=0.1)
    zero_idx_r = np.where(np.diff(row_shifted > 0) != 0)[0]
    for zi in zero_idx_r[:50]:
        ax3.axhline(y=zi, color='purple', linewidth=0.3, alpha=0.5)
    for yl in y_lines:
        ax3.axhline(y=yl, color='red', linewidth=0.8, alpha=0.7)
    ax3.set_title('row projection', fontsize=12)
    ax3.set_ylabel('row')
    ax3.legend(fontsize=8)
    fig3.savefig(os.path.join(OUTPUT_DIR, '03_row_projection.png'),
                 dpi=120, bbox_inches='tight')
    plt.close(fig3)

    # Figure 4: gray-level histogram with the global Otsu threshold.
    fig4, ax4 = plt.subplots(figsize=(8, 5))
    ax4.hist(gray.ravel(), bins=50, color='gray', edgecolor='black',
             linewidth=0.3)
    ax4.axvline(x=T_otsu, color='red', linestyle='--', linewidth=2,
                label=f'Otsu T={T_otsu} (pct={pct*100:.1f}%)')
    ax4.set_title('histogram + Otsu threshold', fontsize=12)
    ax4.set_xlabel('gray value')
    ax4.set_ylabel('pixel count')
    ax4.legend()
    fig4.savefig(os.path.join(OUTPUT_DIR, '04_histogram.png'),
                 dpi=120, bbox_inches='tight')
    plt.close(fig4)

    # Figure 5: per-cell Otsu segmentation before noise removal.
    fig5, ax5 = plt.subplots(figsize=(8, 8))
    ax5.imshow(bw_full, cmap='gray')
    ax5.set_title('per-cell Otsu (before post-processing)', fontsize=13)
    ax5.axis('off')
    fig5.savefig(os.path.join(OUTPUT_DIR, '05_segmentation_raw.png'),
                 dpi=150, bbox_inches='tight')
    plt.close(fig5)

    # Figure 6: final binary mask after noise removal.
    fig6, ax6 = plt.subplots(figsize=(8, 8))
    ax6.imshow(bw_clean, cmap='gray')
    ax6.set_title(f'post-processed ({len(valid)} spots)', fontsize=13)
    ax6.axis('off')
    fig6.savefig(os.path.join(OUTPUT_DIR, '06_post_processed.png'),
                 dpi=150, bbox_inches='tight')
    plt.close(fig6)

    print(f"共保存 6 张图到: {OUTPUT_DIR}")
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()