""" cDNA微阵列图像处理 - Web UI (Flask) ===================================== 启动:python web/app.py 打开:http://localhost:5000 """ import os, sys, io, base64 from flask import Flask, render_template, request, jsonify import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np from PIL import Image from skimage import color from scipy import ndimage # 项目根目录加到 sys.path BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.join(BASE_DIR, 'src')) app = Flask(__name__) app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024 # 50MB UPLOAD_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads') os.makedirs(UPLOAD_DIR, exist_ok=True) plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei'] plt.rcParams['axes.unicode_minus'] = False # ================================================================ # 图像处理函数(同简化版逻辑) # ================================================================ def otsu_threshold_pixels(gray): best_T, best_cost, total = 0, float('inf'), gray.size for T in range(1, 255): bg, fg = gray[gray <= T], gray[gray > T] if len(bg) == 0 or len(fg) == 0: continue cost = len(bg)/total*np.var(bg) + len(fg)/total*np.var(fg) if cost < best_cost: best_cost, best_T = cost, T return best_T def draw_grid_lines(gray): T = otsu_threshold_pixels(gray) pct = T / 255.0 H, W = gray.shape col_prof = np.sum(gray, axis=0).astype(float) row_prof = np.sum(gray, axis=1).astype(float) col_T = (np.max(col_prof)-np.min(col_prof))*pct row_T = (np.max(row_prof)-np.min(row_prof))*pct col_s, row_s = col_prof-col_T, row_prof-row_T def find_gap_lines(prof): is_pos = prof > 0 crossings = [i for i in range(1, len(is_pos)) if is_pos[i] != is_pos[i-1]] if len(crossings) < 2: return np.array([]) start = 1 if not is_pos[0] else 0 return np.array([int((crossings[k]+crossings[k+1])/2) for k in range(start, len(crossings)-1, 2)]) xl = find_gap_lines(col_s) yl = find_gap_lines(row_s) return xl, yl, T, pct, col_prof, 
row_prof, col_s, row_s, col_T, row_T def keep_largest_object(binary): L, n = ndimage.label(binary) if n == 0: return np.zeros_like(binary) return (L == (np.argmax([int(np.sum(L==i)) for i in range(1,n+1)])+1)).astype(np.uint8) def remove_small_objects(binary): L, n = ndimage.label(binary) if n == 0: return binary areas = np.array([int(np.sum(L==i)) for i in range(1,n+1)]) if len(areas) < 2: return binary best_T, best_cost, n_total = 0, float('inf'), len(areas) for T in np.unique(areas): s, l = areas[areas<=T], areas[areas>T] w_s, w_l = len(s)/n_total, len(l)/n_total if w_s==0 or w_l==0: continue cost = w_s*np.var(s) + w_l*np.var(l) if cost < best_cost: best_cost, best_T = cost, T r = binary.copy() for i in range(1, n+1): if int(np.sum(L==i)) < best_T: r[L==i] = 0 return r def fig_to_base64(fig): """matplotlib figure → base64 PNG""" buf = io.BytesIO() fig.savefig(buf, format='png', dpi=120, bbox_inches='tight') buf.seek(0) b64 = base64.b64encode(buf.read()).decode() plt.close(fig) return f'data:image/png;base64,{b64}' def process_image(img_array): """对上传的图像运行完整处理流程,返回 dict""" # 转灰度 if img_array.ndim == 3 and img_array.shape[2] >= 3: gray = (color.rgb2gray(img_array[:,:,:3])*255).astype(np.uint8) else: gray = img_array.astype(np.uint8) # 网格划线 xl, yl, T, pct, cp, rp, cs, rs, cT, rT = draw_grid_lines(gray) # 逐格分割 bw = np.zeros_like(gray) for i in range(len(yl)-1): for j in range(len(xl)-1): r1, r2 = yl[i], yl[i+1] c1, c2 = xl[j], xl[j+1] blk = gray[r1:r2, c1:c2] if blk.size == 0: continue bt = otsu_threshold_pixels(blk) bb = keep_largest_object((blk > bt).astype(np.uint8)) bw[r1:r2, c1:c2] = bb bw_clean = remove_small_objects(bw) # 统计 L, n = ndimage.label(bw_clean) spots = [int(np.sum(L==i)) for i in range(1,n+1)] valid = [s for s in spots if s >= 10] # ---- 生成6张图 ---- images = {} # 1: grid overlay fig, ax = plt.subplots(figsize=(6,6)) ax.imshow(gray, cmap='gray') for x in xl: ax.axvline(x=x, color='lime', linewidth=0.5) for y in yl: ax.axhline(y=y, color='lime', 
linewidth=0.5) ax.set_title(f'Grid ({len(xl)}x{len(yl)})', fontsize=12); ax.axis('off') images['grid_overlay'] = fig_to_base64(fig) # 2: col projection fig, ax = plt.subplots(figsize=(10,4)) xs = np.arange(len(cp)); ax.plot(xs,cp,'b-',lw=0.6); ax.axhline(y=cT,color='orange',ls='--',lw=1) ax.plot(xs,cs,'g-',lw=0.6,alpha=0.5) ax.fill_between(xs,0,cs,where=(cs>0),color='green',alpha=0.1) ax.fill_between(xs,0,cs,where=(cs<0),color='red',alpha=0.1) for x in xl: ax.axvline(x=x,color='red',lw=0.5,alpha=0.5) ax.set_title('Column Projection', fontsize=12); ax.set_xlabel('column') images['col_projection'] = fig_to_base64(fig) # 3: row projection fig, ax = plt.subplots(figsize=(10,4)) ys = np.arange(len(rp)); ax.plot(rp,ys,'b-',lw=0.6); ax.axvline(x=rT,color='orange',ls='--',lw=1) ax.plot(rs,ys,'g-',lw=0.6,alpha=0.5) for y in yl: ax.axhline(y=y,color='red',lw=0.5,alpha=0.5) ax.set_title('Row Projection', fontsize=12); ax.set_ylabel('row') images['row_projection'] = fig_to_base64(fig) # 4: histogram fig, ax = plt.subplots(figsize=(7,4)) ax.hist(gray.ravel(),bins=50,color='#2d8a4e',edgecolor='white',linewidth=0.3) ax.axvline(x=T,color='#ff4444',ls='--',lw=2,label=f'Otsu T={T} ({pct*100:.1f}%)') ax.set_title('Histogram + Otsu', fontsize=12); ax.legend() images['histogram'] = fig_to_base64(fig) # 5: segmentation raw fig, ax = plt.subplots(figsize=(6,6)) ax.imshow(bw, cmap='gray'); ax.set_title('Segmentation (raw)', fontsize=12); ax.axis('off') images['segmentation_raw'] = fig_to_base64(fig) # 6: post processed fig, ax = plt.subplots(figsize=(6,6)) ax.imshow(bw_clean, cmap='gray') ax.set_title(f'Post-processed ({len(valid)} spots)', fontsize=12); ax.axis('off') images['post_processed'] = fig_to_base64(fig) stats = { 'spots': len(valid), 'T_otsu': int(T), 'pct': round(pct*100, 1), 'lines_x': int(len(xl)), 'lines_y': int(len(yl)), 'width': int(gray.shape[1]), 'height': int(gray.shape[0]) } return {'images': images, 'stats': stats} # 
# ================================================================
# Flask routes
# ================================================================

# PIL modes whose np.array() form is already plain intensities or RGB(A).
# Anything else (palette, CMYK, YCbCr, LAB, HSV, ...) is converted to RGB first.
_DIRECT_MODES = {'1', 'L', 'I', 'I;16', 'I;16L', 'I;16B', 'I;16N', 'F', 'RGB', 'RGBA'}


@app.route('/')
def index():
    """Serve the single-page UI."""
    return render_template('index.html')


@app.route('/process', methods=['POST'])
def process():
    """Run ``process_image`` on an uploaded image and return the result as JSON.

    Expects a multipart form with a ``file`` field. Returns 400 with
    ``{'error': ...}`` when the field is missing, the filename is empty, or
    the upload cannot be decoded as an image. Returns 500 when processing
    fails, and otherwise the dict from ``process_image``.
    """
    if 'file' not in request.files:
        return jsonify({'error': '未找到文件'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': '文件名为空'}), 400

    # Decode the upload. Image.open is lazy, so decoding errors can also surface
    # in convert() / np.array(). All of them become a JSON 400 rather than an
    # unhandled exception (HTML 500 page).
    try:
        with Image.open(io.BytesIO(file.read())) as src:
            img = src if src.mode in _DIRECT_MODES else src.convert('RGB')
            img_array = np.array(img)
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        return jsonify({'error': f'无法读取图像: {e}'}), 400

    try:
        result = process_image(img_array)
    except Exception as e:  # request boundary: log it and report it as JSON
        app.logger.exception('process_image failed')
        return jsonify({'error': f'处理失败: {str(e)}'}), 500
    return jsonify(result)


if __name__ == '__main__':
    import threading
    import webbrowser

    def open_browser():
        """Open the UI in the default browser (portable, unlike os.startfile)."""
        webbrowser.open('http://localhost:5000')

    # With debug=True the Werkzeug reloader re-executes this script in a child
    # process that has WERKZEUG_RUN_MAIN set. Open the browser only from the
    # parent, so it happens once instead of on every start/reload.
    if os.environ.get('WERKZEUG_RUN_MAIN') != 'true':
        threading.Timer(1.5, open_browser).start()
    app.run(debug=True, port=5000)