feat: 添加强化学习项目报告及重构课程作业报告代码结构
- 新增强化学习个人项目报告,包含基于PyTorch从零实现的PPO算法 - 重构课程作业报告代码结构,提取运行时路径管理和notebook执行逻辑到独立模块 - 更新依赖文件requirements.txt,添加强化学习相关依赖 - 简化模型比较结果表格,仅保留基线逻辑回归模型数据
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
"""
|
||||
运行 insurance_premium_risk.ipynb 的脚本
|
||||
将 notebook 代码单元格提取出来逐个执行
|
||||
"""
|
||||
import json, sys, os, warnings, traceback, time
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as _real_mpl_plt
|
||||
|
||||
_real_mpl_plt.show = lambda *a, **kw: None
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import traceback
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -32,34 +35,18 @@ import xgboost as xgb
|
||||
import optuna
|
||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||
|
||||
RANDOM_STATE = 42
|
||||
np.random.seed(RANDOM_STATE)
|
||||
plt.rcParams['figure.figsize'] = (10, 6)
|
||||
plt.rcParams['font.size'] = 12
|
||||
sns.set_style('whitegrid')
|
||||
from src.notebook_runner import execute_notebook
|
||||
from src.runtime_paths import build_paths
|
||||
|
||||
# ===== 读取 notebook =====
|
||||
nb_path = r'd:\Code\doing_exercises\programs\外教作业外快\强化学习个人课程作业报告\notebooks\insurance_premium_risk.ipynb'
|
||||
cells = json.load(open(nb_path, encoding='utf-8'))['cells']
|
||||
code_cells = [c for c in cells if c['cell_type'] == 'code']
|
||||
print(f"Total code cells: {len(code_cells)}")
|
||||
paths = build_paths()
|
||||
print(f"Project root : {paths.project_root}")
|
||||
print(f"Notebook : {paths.notebook}")
|
||||
print(f"Data dir : {paths.data_dir}")
|
||||
print(f"Output dir : {paths.output_dir}")
|
||||
|
||||
# ===== 执行每个单元格 =====
|
||||
# 使用全局 __main__ 命名空间,变量跨单元格持久化
|
||||
main_ns = globals().copy()
|
||||
ns = vars()
|
||||
|
||||
for i, cell in enumerate(code_cells):
|
||||
src = ''.join(cell['source'])
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Running cell {i+1}/{len(code_cells)}...")
|
||||
print(f" Source: {src[:80].replace(chr(10), ' ')}")
|
||||
try:
|
||||
exec(compile(src, f'cell_{i+1}', 'exec'), main_ns)
|
||||
except Exception as e:
|
||||
print(f"ERROR in cell {i+1}: {e}")
|
||||
traceback.print_exc()
|
||||
print("Stopping execution.")
|
||||
break
|
||||
|
||||
print("\n\nAll cells executed successfully!")
|
||||
print(f"Results saved to: outputs/figures/ and outputs/tables/")
|
||||
result = execute_notebook(namespace=ns)
|
||||
print(f"\nExecution finished: {result['status']}")
|
||||
print(f"Cells run: {len([c for c in result['cells'] if c['status'] == 'ok'])}/{result['total']}")
|
||||
print(f"Output dir: {result['outputs']['output_dir']}")
|
||||
Reference in New Issue
Block a user