feat: Flask Web UI — 在线cDNA图像处理平台

- 上传图像 + 实时处理 + 6张结果可视化
- 实验室仪器风格深色主题
- 参数统计面板(T/pct/网格/斑点数)
- 图片点击放大 + 单张/全部下载
This commit is contained in:
2026-05-08 11:26:02 +08:00
parent 862d02dce6
commit b07e7a1182
3 changed files with 601 additions and 0 deletions
+211
View File
@@ -0,0 +1,211 @@
"""
cDNA微阵列图像处理 - Web UI (Flask)
=====================================
启动:python web/app.py
打开:http://localhost:5000
"""
import os, sys, io, base64
from flask import Flask, render_template, request, jsonify
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage import color
from scipy import ndimage
# 项目根目录加到 sys.path
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(BASE_DIR, 'src'))
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024 # 50MB
UPLOAD_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
os.makedirs(UPLOAD_DIR, exist_ok=True)
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False
# ================================================================
# 图像处理函数(同简化版逻辑)
# ================================================================
def otsu_threshold_pixels(gray):
best_T, best_cost, total = 0, float('inf'), gray.size
for T in range(1, 255):
bg, fg = gray[gray <= T], gray[gray > T]
if len(bg) == 0 or len(fg) == 0:
continue
cost = len(bg)/total*np.var(bg) + len(fg)/total*np.var(fg)
if cost < best_cost:
best_cost, best_T = cost, T
return best_T
def draw_grid_lines(gray):
T = otsu_threshold_pixels(gray)
pct = T / 255.0
H, W = gray.shape
col_prof = np.sum(gray, axis=0).astype(float)
row_prof = np.sum(gray, axis=1).astype(float)
col_T = (np.max(col_prof)-np.min(col_prof))*pct
row_T = (np.max(row_prof)-np.min(row_prof))*pct
col_s, row_s = col_prof-col_T, row_prof-row_T
def find_gap_lines(prof):
is_pos = prof > 0
crossings = [i for i in range(1, len(is_pos)) if is_pos[i] != is_pos[i-1]]
if len(crossings) < 2:
return np.array([])
start = 1 if not is_pos[0] else 0
return np.array([int((crossings[k]+crossings[k+1])/2) for k in range(start, len(crossings)-1, 2)])
xl = find_gap_lines(col_s)
yl = find_gap_lines(row_s)
return xl, yl, T, pct, col_prof, row_prof, col_s, row_s, col_T, row_T
def keep_largest_object(binary):
L, n = ndimage.label(binary)
if n == 0: return np.zeros_like(binary)
return (L == (np.argmax([int(np.sum(L==i)) for i in range(1,n+1)])+1)).astype(np.uint8)
def remove_small_objects(binary):
L, n = ndimage.label(binary)
if n == 0: return binary
areas = [int(np.sum(L==i)) for i in range(1,n+1)]
minsz = max(1, int(np.median(areas)*0.25))
r = binary.copy()
for i in range(1, n+1):
if areas[i-1] < minsz: r[L==i] = 0
return r
def fig_to_base64(fig):
"""matplotlib figure → base64 PNG"""
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
buf.seek(0)
b64 = base64.b64encode(buf.read()).decode()
plt.close(fig)
return f'data:image/png;base64,{b64}'
def process_image(img_array):
"""对上传的图像运行完整处理流程,返回 dict"""
# 转灰度
if img_array.ndim == 3 and img_array.shape[2] >= 3:
gray = (color.rgb2gray(img_array[:,:,:3])*255).astype(np.uint8)
else:
gray = img_array.astype(np.uint8)
# 网格划线
xl, yl, T, pct, cp, rp, cs, rs, cT, rT = draw_grid_lines(gray)
# 逐格分割
bw = np.zeros_like(gray)
for i in range(len(yl)-1):
for j in range(len(xl)-1):
r1, r2 = yl[i], yl[i+1]
c1, c2 = xl[j], xl[j+1]
blk = gray[r1:r2, c1:c2]
if blk.size == 0: continue
bt = otsu_threshold_pixels(blk)
bb = keep_largest_object((blk > bt).astype(np.uint8))
bw[r1:r2, c1:c2] = bb
bw_clean = remove_small_objects(bw)
# 统计
L, n = ndimage.label(bw_clean)
spots = [int(np.sum(L==i)) for i in range(1,n+1)]
valid = [s for s in spots if s >= 10]
# ---- 生成6张图 ----
images = {}
# 1: grid overlay
fig, ax = plt.subplots(figsize=(6,6))
ax.imshow(gray, cmap='gray')
for x in xl: ax.axvline(x=x, color='lime', linewidth=0.5)
for y in yl: ax.axhline(y=y, color='lime', linewidth=0.5)
ax.set_title(f'Grid ({len(xl)}x{len(yl)})', fontsize=12); ax.axis('off')
images['grid_overlay'] = fig_to_base64(fig)
# 2: col projection
fig, ax = plt.subplots(figsize=(10,4))
xs = np.arange(len(cp)); ax.plot(xs,cp,'b-',lw=0.6); ax.axhline(y=cT,color='orange',ls='--',lw=1)
ax.plot(xs,cs,'g-',lw=0.6,alpha=0.5)
ax.fill_between(xs,0,cs,where=(cs>0),color='green',alpha=0.1)
ax.fill_between(xs,0,cs,where=(cs<0),color='red',alpha=0.1)
for x in xl: ax.axvline(x=x,color='red',lw=0.5,alpha=0.5)
ax.set_title('Column Projection', fontsize=12); ax.set_xlabel('column')
images['col_projection'] = fig_to_base64(fig)
# 3: row projection
fig, ax = plt.subplots(figsize=(10,4))
ys = np.arange(len(rp)); ax.plot(rp,ys,'b-',lw=0.6); ax.axvline(x=rT,color='orange',ls='--',lw=1)
ax.plot(rs,ys,'g-',lw=0.6,alpha=0.5)
for y in yl: ax.axhline(y=y,color='red',lw=0.5,alpha=0.5)
ax.set_title('Row Projection', fontsize=12); ax.set_ylabel('row')
images['row_projection'] = fig_to_base64(fig)
# 4: histogram
fig, ax = plt.subplots(figsize=(7,4))
ax.hist(gray.ravel(),bins=50,color='#2d8a4e',edgecolor='white',linewidth=0.3)
ax.axvline(x=T,color='#ff4444',ls='--',lw=2,label=f'Otsu T={T} ({pct*100:.1f}%)')
ax.set_title('Histogram + Otsu', fontsize=12); ax.legend()
images['histogram'] = fig_to_base64(fig)
# 5: segmentation raw
fig, ax = plt.subplots(figsize=(6,6))
ax.imshow(bw, cmap='gray'); ax.set_title('Segmentation (raw)', fontsize=12); ax.axis('off')
images['segmentation_raw'] = fig_to_base64(fig)
# 6: post processed
fig, ax = plt.subplots(figsize=(6,6))
ax.imshow(bw_clean, cmap='gray')
ax.set_title(f'Post-processed ({len(valid)} spots)', fontsize=12); ax.axis('off')
images['post_processed'] = fig_to_base64(fig)
stats = {
'spots': len(valid),
'T_otsu': int(T),
'pct': round(pct*100, 1),
'lines_x': int(len(xl)),
'lines_y': int(len(yl)),
'width': int(gray.shape[1]),
'height': int(gray.shape[0])
}
return {'images': images, 'stats': stats}
# ================================================================
# Flask 路由
# ================================================================
@app.route('/')
def index():
return render_template('index.html')
@app.route('/process', methods=['POST'])
def process():
if 'file' not in request.files:
return jsonify({'error': '未找到文件'}), 400
file = request.files['file']
if file.filename == '':
return jsonify({'error': '文件名为空'}), 400
# 读取图像
img_bytes = file.read()
img = Image.open(io.BytesIO(img_bytes))
img_array = np.array(img)
# 处理
try:
result = process_image(img_array)
return jsonify(result)
except Exception as e:
return jsonify({'error': f'处理失败: {str(e)}'}), 500
if __name__ == '__main__':
app.run(debug=True, port=5000)