Files
rl-atari/强化学习个人课程作业报告/run_notebook.py
T
Serendipity ceddbdd559 Add lecture materials for Model-Free, Control, and Value topics
- Added Lecture4 - ModelFree.pdf (3013 KB)
- Added Lecture5 - Control.pdf (2575 KB)
- Added Lecture6 - Value.pdf (3320 KB)
2026-04-28 20:28:00 +08:00

66 lines
2.3 KiB
Python

"""
运行 insurance_premium_risk.ipynb 的脚本
将 notebook 代码单元格提取出来逐个执行
"""
import json, sys, os, warnings, traceback, time
warnings.filterwarnings('ignore')
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as _real_mpl_plt
_real_mpl_plt.show = lambda *a, **kw: None
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA
import xgboost as xgb
import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12
sns.set_style('whitegrid')
# ===== Read the notebook =====
# Default notebook location; pass another path as the first command-line
# argument to run a different notebook.
nb_path = r'd:\Code\doing_exercises\programs\外教作业外快\强化学习个人课程作业报告\notebooks\insurance_premium_risk.ipynb'


def load_code_cells(path):
    """Return the code cells of the notebook stored at *path*.

    Args:
        path: Filesystem path of a UTF-8 encoded ``.ipynb`` (JSON) file.

    Returns:
        The list of cell dicts whose ``cell_type`` is ``'code'``, in
        notebook order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the document has no top-level ``'cells'`` entry.
    """
    # Context manager guarantees the handle is closed even if parsing fails.
    with open(path, encoding='utf-8') as fh:
        notebook = json.load(fh)
    return [c for c in notebook['cells'] if c['cell_type'] == 'code']


def run_code_cells(code_cells, namespace):
    """Execute *code_cells* in order inside the shared *namespace*.

    Variables defined by one cell stay in *namespace* and are therefore
    visible to every later cell. Execution stops at the first cell that
    raises; the error and its traceback are printed.

    Args:
        code_cells: Sequence of notebook cell dicts with a ``'source'`` entry
            (a string or a list of strings).
        namespace: Dict used as the globals for every ``exec`` call; mutated
            in place.

    Returns:
        True if every cell ran without raising, False otherwise.
    """
    total = len(code_cells)
    for i, cell in enumerate(code_cells, start=1):
        src = ''.join(cell['source'])
        print(f"\n{'='*60}")
        print(f"Running cell {i}/{total}...")
        print(f" Source: {src[:80].replace(chr(10), ' ')}")
        try:
            # NOTE: exec runs the notebook's code with this process's full
            # privileges -- only point this script at notebooks you trust.
            exec(compile(src, f'cell_{i}', 'exec'), namespace)
        except Exception as e:
            # Top-level boundary: report the failure (compile errors
            # included) and stop instead of running dependent cells.
            print(f"ERROR in cell {i}: {e}")
            traceback.print_exc()
            print("Stopping execution.")
            return False
    return True


def main(argv=None):
    """Load the notebook, run its code cells, and return an exit status.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). Defaults to ``sys.argv[1:]``. The first item, if present,
            overrides the default notebook path ``nb_path``.

    Returns:
        0 when every cell executed, 1 when a cell failed.
    """
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else nb_path

    code_cells = load_code_cells(path)
    print(f"Total code cells: {len(code_cells)}")

    # ===== Execute every cell =====
    # Run in a copy of this module's globals so the libraries imported at the
    # top of the file are available and variables persist across cells.
    main_ns = globals().copy()
    if not run_code_cells(code_cells, main_ns):
        # Do not claim success (or exit 0) after a failed cell.
        return 1

    print("\n\nAll cells executed successfully!")
    print("Results saved to: outputs/figures/ and outputs/tables/")
    return 0


if __name__ == '__main__':
    sys.exit(main())