{
"cells": [
{
"cell_type": "markdown",
"id": "273bffe9",
"metadata": {
"origin_pos": 0
},
"source": [
"# 风格迁移\n",
"\n",
"摄影爱好者也许接触过滤波器。它能改变照片的颜色风格,从而使风景照更加锐利或者令人像更加美白。但一个滤波器通常只能改变照片的某个方面。如果要照片达到理想中的风格,可能需要尝试大量不同的组合。这个过程的复杂程度不亚于模型调参。\n",
"\n",
"本节将介绍如何使用卷积神经网络,自动将一个图像中的风格应用在另一图像之上,即*风格迁移*(style transfer) :cite:`Gatys.Ecker.Bethge.2016`。\n",
"这里我们需要两张输入图像:一张是*内容图像*,另一张是*风格图像*。\n",
"我们将使用神经网络修改内容图像,使其在风格上接近风格图像。\n",
"例如, :numref:`fig_style_transfer`中的内容图像为本书作者在西雅图郊区的雷尼尔山国家公园拍摄的风景照,而风格图像则是一幅主题为秋天橡树的油画。\n",
"最终输出的合成图像应用了风格图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。\n",
"\n",
"\n",
":label:`fig_style_transfer`\n",
"\n",
"## 方法\n",
"\n",
" :numref:`fig_style_transfer_model`用简单的例子阐述了基于卷积神经网络的风格迁移方法。\n",
"首先,我们初始化合成图像,例如将其初始化为内容图像。\n",
"该合成图像是风格迁移过程中唯一需要更新的变量,即风格迁移所需迭代的模型参数。\n",
"然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。\n",
"这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们可以选择其中某些层的输出作为内容特征或风格特征。\n",
"以 :numref:`fig_style_transfer_model`为例,这里选取的预训练的神经网络含有3个卷积层,其中第二层输出内容特征,第一层和第三层输出风格特征。\n",
"\n",
"\n",
":label:`fig_style_transfer_model`\n",
"\n",
"接下来,我们通过前向传播(实线箭头方向)计算风格迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像。\n",
"风格迁移常用的损失函数由3部分组成:\n",
"\n",
"1. *内容损失*使合成图像与内容图像在内容特征上接近;\n",
"1. *风格损失*使合成图像与风格图像在风格特征上接近;\n",
"1. *全变分损失*则有助于减少合成图像中的噪点。\n",
"\n",
"最后,当模型训练结束时,我们输出风格迁移的模型参数,即得到最终的合成图像。\n",
"\n",
"在下面,我们将通过代码来进一步了解风格迁移的技术细节。\n",
"\n",
"## [**阅读内容和风格图像**]\n",
"\n",
"首先,我们读取内容和风格图像。\n",
"从打印出的图像坐标轴可以看出,它们的尺寸并不一样。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a0d90f51",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:26.021505Z",
"iopub.status.busy": "2023-08-18T07:23:26.020759Z",
"iopub.status.idle": "2023-08-18T07:23:29.597245Z",
"shell.execute_reply": "2023-08-18T07:23:29.595990Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import torch\n",
"import torchvision\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"d2l.set_figsize()\n",
"content_img = d2l.Image.open('../img/rainier.jpg')\n",
"d2l.plt.imshow(content_img);"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ec590a65",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:29.601550Z",
"iopub.status.busy": "2023-08-18T07:23:29.600514Z",
"iopub.status.idle": "2023-08-18T07:23:30.096132Z",
"shell.execute_reply": "2023-08-18T07:23:30.095315Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"style_img = d2l.Image.open('../img/autumn-oak.jpg')\n",
"d2l.plt.imshow(style_img);"
]
},
{
"cell_type": "markdown",
"id": "ddc886a6",
"metadata": {
"origin_pos": 6
},
"source": [
"## [**预处理和后处理**]\n",
"\n",
"下面,定义图像的预处理函数和后处理函数。\n",
"预处理函数`preprocess`对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。\n",
"后处理函数`postprocess`则将输出图像中的像素值还原回标准化之前的值。\n",
"由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6f351192",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:30.103272Z",
"iopub.status.busy": "2023-08-18T07:23:30.102388Z",
"iopub.status.idle": "2023-08-18T07:23:30.112076Z",
"shell.execute_reply": "2023-08-18T07:23:30.111052Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"rgb_mean = torch.tensor([0.485, 0.456, 0.406])\n",
"rgb_std = torch.tensor([0.229, 0.224, 0.225])\n",
"\n",
"def preprocess(img, image_shape):\n",
" transforms = torchvision.transforms.Compose([\n",
" torchvision.transforms.Resize(image_shape),\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])\n",
" return transforms(img).unsqueeze(0)\n",
"\n",
"def postprocess(img):\n",
" img = img[0].to(rgb_std.device)\n",
" img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)\n",
" return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))"
]
},
{
"cell_type": "markdown",
"id": "54e9f5d8",
"metadata": {
"origin_pos": 10
},
"source": [
"## [**抽取图像特征**]\n",
"\n",
"我们使用基于ImageNet数据集预训练的VGG-19模型来抽取图像特征 :cite:`Gatys.Ecker.Bethge.2016`。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f562ee81",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:30.117700Z",
"iopub.status.busy": "2023-08-18T07:23:30.116834Z",
"iopub.status.idle": "2023-08-18T07:23:42.885822Z",
"shell.execute_reply": "2023-08-18T07:23:42.884582Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/vgg19-dcbb9e9d.pth\" to /home/ci/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth\n"
]
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.007704734802246094,
"initial": 0,
"n": 0,
"ncols": null,
"nrows": null,
"postfix": null,
"prefix": "",
"rate": null,
"total": 574673361,
"unit": "B",
"unit_divisor": 1024,
"unit_scale": true
},
"application/vnd.jupyter.widget-view+json": {
"model_id": "0a21ebb04f6e4afe9df09a7d7c6a0fe0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0.00/548M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pretrained_net = torchvision.models.vgg19(pretrained=True)"
]
},
{
"cell_type": "markdown",
"id": "85fcfc18",
"metadata": {
"origin_pos": 14
},
"source": [
"为了抽取图像的内容特征和风格特征,我们可以选择VGG网络中某些层的输出。\n",
"一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。\n",
"为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即*内容层*,来输出图像的内容特征。\n",
"我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为*风格层*。\n",
"正如 :numref:`sec_vgg`中所介绍的,VGG网络使用了5个卷积块。\n",
"实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。\n",
"这些层的索引可以通过打印`pretrained_net`实例获取。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9c71d01f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.891197Z",
"iopub.status.busy": "2023-08-18T07:23:42.890918Z",
"iopub.status.idle": "2023-08-18T07:23:42.895437Z",
"shell.execute_reply": "2023-08-18T07:23:42.894539Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"style_layers, content_layers = [0, 5, 10, 19, 28], [25]"
]
},
{
"cell_type": "markdown",
"id": "a02682f4",
"metadata": {
"origin_pos": 16
},
"source": [
"使用VGG层抽取特征时,我们只需要用到从输入层到最靠近输出层的内容层或风格层之间的所有层。\n",
"下面构建一个新的网络`net`,它只保留需要用到的VGG的所有层。\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "75c742e5",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.901208Z",
"iopub.status.busy": "2023-08-18T07:23:42.900137Z",
"iopub.status.idle": "2023-08-18T07:23:42.906883Z",
"shell.execute_reply": "2023-08-18T07:23:42.905548Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"net = nn.Sequential(*[pretrained_net.features[i] for i in\n",
" range(max(content_layers + style_layers) + 1)])"
]
},
{
"cell_type": "markdown",
"id": "932aec9e",
"metadata": {
"origin_pos": 19
},
"source": [
"给定输入`X`,如果我们简单地调用前向传播`net(X)`,只能获得最后一层的输出。\n",
"由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出。\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5cd1ccc0",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.911258Z",
"iopub.status.busy": "2023-08-18T07:23:42.910587Z",
"iopub.status.idle": "2023-08-18T07:23:42.929103Z",
"shell.execute_reply": "2023-08-18T07:23:42.927620Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def extract_features(X, content_layers, style_layers):\n",
" contents = []\n",
" styles = []\n",
" for i in range(len(net)):\n",
" X = net[i](X)\n",
" if i in style_layers:\n",
" styles.append(X)\n",
" if i in content_layers:\n",
" contents.append(X)\n",
" return contents, styles"
]
},
{
"cell_type": "markdown",
"id": "49db6d46",
"metadata": {
"origin_pos": 21
},
"source": [
"下面定义两个函数:`get_contents`函数对内容图像抽取内容特征;\n",
"`get_styles`函数对风格图像抽取风格特征。\n",
"因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和风格特征。\n",
"由于合成图像是风格迁移所需迭代的模型参数,我们只能在训练过程中通过调用`extract_features`函数来抽取合成图像的内容特征和风格特征。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f80b015d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.934583Z",
"iopub.status.busy": "2023-08-18T07:23:42.933552Z",
"iopub.status.idle": "2023-08-18T07:23:42.955842Z",
"shell.execute_reply": "2023-08-18T07:23:42.954343Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_contents(image_shape, device):\n",
" content_X = preprocess(content_img, image_shape).to(device)\n",
" contents_Y, _ = extract_features(content_X, content_layers, style_layers)\n",
" return content_X, contents_Y\n",
"\n",
"def get_styles(image_shape, device):\n",
" style_X = preprocess(style_img, image_shape).to(device)\n",
" _, styles_Y = extract_features(style_X, content_layers, style_layers)\n",
" return style_X, styles_Y"
]
},
{
"cell_type": "markdown",
"id": "0a04d737",
"metadata": {
"origin_pos": 25
},
"source": [
"## [**定义损失函数**]\n",
"\n",
"下面我们来描述风格迁移的损失函数。\n",
"它由内容损失、风格损失和全变分损失3部分组成。\n",
"\n",
"### 内容损失\n",
"\n",
"与线性回归中的损失函数类似,内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异。\n",
"平方误差函数的两个输入均为`extract_features`函数计算所得到的内容层的输出。\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1048e5c2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.961660Z",
"iopub.status.busy": "2023-08-18T07:23:42.961194Z",
"iopub.status.idle": "2023-08-18T07:23:42.967349Z",
"shell.execute_reply": "2023-08-18T07:23:42.966138Z"
},
"origin_pos": 27,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def content_loss(Y_hat, Y):\n",
" # 我们从动态计算梯度的树中分离目标:\n",
" # 这是一个规定的值,而不是一个变量。\n",
" return torch.square(Y_hat - Y.detach()).mean()"
]
},
{
"cell_type": "markdown",
"id": "71b083a8",
"metadata": {
"origin_pos": 29
},
"source": [
"### 风格损失\n",
"\n",
"风格损失与内容损失类似,也通过平方误差函数衡量合成图像与风格图像在风格上的差异。\n",
"为了表达风格层输出的风格,我们先通过`extract_features`函数计算风格层的输出。\n",
"假设该输出的样本数为1,通道数为$c$,高和宽分别为$h$和$w$,我们可以将此输出转换为矩阵$\\mathbf{X}$,其有$c$行和$hw$列。\n",
"这个矩阵可以被看作由$c$个长度为$hw$的向量$\\mathbf{x}_1, \\ldots, \\mathbf{x}_c$组合而成的。其中向量$\\mathbf{x}_i$代表了通道$i$上的风格特征。\n",
"\n",
"在这些向量的*格拉姆矩阵*$\\mathbf{X}\\mathbf{X}^\\top \\in \\mathbb{R}^{c \\times c}$中,$i$行$j$列的元素$x_{ij}$即向量$\\mathbf{x}_i$和$\\mathbf{x}_j$的内积。它表达了通道$i$和通道$j$上风格特征的相关性。我们用这样的格拉姆矩阵来表达风格层输出的风格。\n",
"需要注意的是,当$hw$的值较大时,格拉姆矩阵中的元素容易出现较大的值。\n",
"此外,格拉姆矩阵的高和宽皆为通道数$c$。\n",
"为了让风格损失不受这些值的大小影响,下面定义的`gram`函数将格拉姆矩阵除以了矩阵中元素的个数,即$chw$。\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "207704c1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.973259Z",
"iopub.status.busy": "2023-08-18T07:23:42.971937Z",
"iopub.status.idle": "2023-08-18T07:23:42.979314Z",
"shell.execute_reply": "2023-08-18T07:23:42.978380Z"
},
"origin_pos": 30,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def gram(X):\n",
" num_channels, n = X.shape[1], X.numel() // X.shape[1]\n",
" X = X.reshape((num_channels, n))\n",
" return torch.matmul(X, X.T) / (num_channels * n)"
]
},
{
"cell_type": "markdown",
"id": "3c362780",
"metadata": {
"origin_pos": 31
},
"source": [
"自然地,风格损失的平方误差函数的两个格拉姆矩阵输入分别基于合成图像与风格图像的风格层输出。这里假设基于风格图像的格拉姆矩阵`gram_Y`已经预先计算好了。\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3491c1fd",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.984195Z",
"iopub.status.busy": "2023-08-18T07:23:42.983439Z",
"iopub.status.idle": "2023-08-18T07:23:42.988675Z",
"shell.execute_reply": "2023-08-18T07:23:42.987781Z"
},
"origin_pos": 33,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def style_loss(Y_hat, gram_Y):\n",
" return torch.square(gram(Y_hat) - gram_Y.detach()).mean()"
]
},
{
"cell_type": "markdown",
"id": "44caeefb",
"metadata": {
"origin_pos": 35
},
"source": [
"### 全变分损失\n",
"\n",
"有时候,我们学到的合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。\n",
"一种常见的去噪方法是*全变分去噪*(total variation denoising):\n",
"假设$x_{i, j}$表示坐标$(i, j)$处的像素值,降低全变分损失\n",
"\n",
"$$\\sum_{i, j} \\left|x_{i, j} - x_{i+1, j}\\right| + \\left|x_{i, j} - x_{i, j+1}\\right|$$\n",
"\n",
"能够尽可能使邻近的像素值相似。\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2173f076",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:42.994009Z",
"iopub.status.busy": "2023-08-18T07:23:42.992920Z",
"iopub.status.idle": "2023-08-18T07:23:43.000113Z",
"shell.execute_reply": "2023-08-18T07:23:42.998890Z"
},
"origin_pos": 36,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def tv_loss(Y_hat):\n",
" return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +\n",
" torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())"
]
},
{
"cell_type": "markdown",
"id": "335a9e26",
"metadata": {
"origin_pos": 37
},
"source": [
"### 损失函数\n",
"\n",
"[**风格转移的损失函数是内容损失、风格损失和总变化损失的加权和**]。\n",
"通过调节这些权重超参数,我们可以权衡合成图像在保留内容、迁移风格以及去噪三方面的相对重要性。\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "7b2d722a",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:43.004890Z",
"iopub.status.busy": "2023-08-18T07:23:43.004502Z",
"iopub.status.idle": "2023-08-18T07:23:43.012475Z",
"shell.execute_reply": "2023-08-18T07:23:43.011392Z"
},
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"content_weight, style_weight, tv_weight = 1, 1e3, 10\n",
"\n",
"def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):\n",
" # 分别计算内容损失、风格损失和全变分损失\n",
" contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(\n",
" contents_Y_hat, contents_Y)]\n",
" styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(\n",
" styles_Y_hat, styles_Y_gram)]\n",
" tv_l = tv_loss(X) * tv_weight\n",
" # 对所有损失求和\n",
" l = sum(10 * styles_l + contents_l + [tv_l])\n",
" return contents_l, styles_l, tv_l, l"
]
},
{
"cell_type": "markdown",
"id": "9f90235c",
"metadata": {
"origin_pos": 39
},
"source": [
"## [**初始化合成图像**]\n",
"\n",
"在风格迁移中,合成的图像是训练期间唯一需要更新的变量。因此,我们可以定义一个简单的模型`SynthesizedImage`,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可。\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a4f99f98",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:43.018222Z",
"iopub.status.busy": "2023-08-18T07:23:43.017815Z",
"iopub.status.idle": "2023-08-18T07:23:43.031792Z",
"shell.execute_reply": "2023-08-18T07:23:43.030715Z"
},
"origin_pos": 41,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class SynthesizedImage(nn.Module):\n",
" def __init__(self, img_shape, **kwargs):\n",
" super(SynthesizedImage, self).__init__(**kwargs)\n",
" self.weight = nn.Parameter(torch.rand(*img_shape))\n",
"\n",
" def forward(self):\n",
" return self.weight"
]
},
{
"cell_type": "markdown",
"id": "f7a98b0a",
"metadata": {
"origin_pos": 43
},
"source": [
"下面,我们定义`get_inits`函数。该函数创建了合成图像的模型实例,并将其初始化为图像`X`。风格图像在各个风格层的格拉姆矩阵`styles_Y_gram`将在训练前预先计算好。\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3055aa3b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:43.037767Z",
"iopub.status.busy": "2023-08-18T07:23:43.037371Z",
"iopub.status.idle": "2023-08-18T07:23:43.044082Z",
"shell.execute_reply": "2023-08-18T07:23:43.043078Z"
},
"origin_pos": 45,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_inits(X, device, lr, styles_Y):\n",
" gen_img = SynthesizedImage(X.shape).to(device)\n",
" gen_img.weight.data.copy_(X.data)\n",
" trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)\n",
" styles_Y_gram = [gram(Y) for Y in styles_Y]\n",
" return gen_img(), styles_Y_gram, trainer"
]
},
{
"cell_type": "markdown",
"id": "6b86cbc2",
"metadata": {
"origin_pos": 47
},
"source": [
"## [**训练模型**]\n",
"\n",
"在训练模型进行风格迁移时,我们不断抽取合成图像的内容特征和风格特征,然后计算损失函数。下面定义了训练循环。\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "4fe3c6e1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:43.048455Z",
"iopub.status.busy": "2023-08-18T07:23:43.048174Z",
"iopub.status.idle": "2023-08-18T07:23:43.056676Z",
"shell.execute_reply": "2023-08-18T07:23:43.055451Z"
},
"origin_pos": 49,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):\n",
" X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)\n",
" scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)\n",
" animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
" xlim=[10, num_epochs],\n",
" legend=['content', 'style', 'TV'],\n",
" ncols=2, figsize=(7, 2.5))\n",
" for epoch in range(num_epochs):\n",
" trainer.zero_grad()\n",
" contents_Y_hat, styles_Y_hat = extract_features(\n",
" X, content_layers, style_layers)\n",
" contents_l, styles_l, tv_l, l = compute_loss(\n",
" X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)\n",
" l.backward()\n",
" trainer.step()\n",
" scheduler.step()\n",
" if (epoch + 1) % 10 == 0:\n",
" animator.axes[1].imshow(postprocess(X))\n",
" animator.add(epoch + 1, [float(sum(contents_l)),\n",
" float(sum(styles_l)), float(tv_l)])\n",
" return X"
]
},
{
"cell_type": "markdown",
"id": "7d70c1ac",
"metadata": {
"origin_pos": 51
},
"source": [
"现在我们[**训练模型**]:\n",
"首先将内容图像和风格图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像。\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c0846fe5",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:23:43.061124Z",
"iopub.status.busy": "2023-08-18T07:23:43.060316Z",
"iopub.status.idle": "2023-08-18T07:24:35.646273Z",
"shell.execute_reply": "2023-08-18T07:24:35.645421Z"
},
"origin_pos": 53,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"device, image_shape = d2l.try_gpu(), (300, 450)\n",
"net = net.to(device)\n",
"content_X, contents_Y = get_contents(image_shape, device)\n",
"_, styles_Y = get_styles(image_shape, device)\n",
"output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)"
]
},
{
"cell_type": "markdown",
"id": "7f5c2480",
"metadata": {
"origin_pos": 55
},
"source": [
"我们可以看到,合成图像保留了内容图像的风景和物体,并同时迁移了风格图像的色彩。例如,合成图像具有与风格图像中一样的色彩块,其中一些甚至具有画笔笔触的细微纹理。\n",
"\n",
"## 小结\n",
"\n",
"* 风格迁移常用的损失函数由3部分组成:(1)内容损失使合成图像与内容图像在内容特征上接近;(2)风格损失令合成图像与风格图像在风格特征上接近;(3)全变分损失则有助于减少合成图像中的噪点。\n",
"* 我们可以通过预训练的卷积神经网络来抽取图像的特征,并通过最小化损失函数来不断更新合成图像来作为模型参数。\n",
"* 我们使用格拉姆矩阵表达风格层输出的风格。\n",
"\n",
"## 练习\n",
"\n",
"1. 选择不同的内容和风格层,输出有什么变化?\n",
"1. 调整损失函数中的权重超参数。输出是否保留更多内容或减少更多噪点?\n",
"1. 替换实验中的内容图像和风格图像,能创作出更有趣的合成图像吗?\n",
"1. 我们可以对文本使用风格迁移吗?提示:可以参阅调查报告 :cite:`Hu.Lee.Aggarwal.ea.2020`。\n"
]
},
{
"cell_type": "markdown",
"id": "8888edcd",
"metadata": {
"origin_pos": 57,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/3300)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": [],
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {
"0a21ebb04f6e4afe9df09a7d7c6a0fe0": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_3386849c303b40d18895bb91db97e325",
"IPY_MODEL_155d363cdf40442e8faf86c2f0def49d",
"IPY_MODEL_6840bc285801445eafe45e9cfc4a3216"
],
"layout": "IPY_MODEL_96dfbfe851a544c09ae13797ba4d4198",
"tabbable": null,
"tooltip": null
}
},
"155d363cdf40442e8faf86c2f0def49d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_6e8fb617074d452eb01dbfe715d3827f",
"max": 574673361.0,
"min": 0.0,
"orientation": "horizontal",
"style": "IPY_MODEL_a24b158989e04d0e90ba67a8b670ce52",
"tabbable": null,
"tooltip": null,
"value": 574673361.0
}
},
"2d45b8677d764d7f9e04f2ab4f38d40b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3386849c303b40d18895bb91db97e325": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_2d45b8677d764d7f9e04f2ab4f38d40b",
"placeholder": "",
"style": "IPY_MODEL_bc6bcbf06ef44f2585d8d82ac182ea56",
"tabbable": null,
"tooltip": null,
"value": "100%"
}
},
"5223c9213fef443497e360398c21149f": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
},
"57464fe9afd448b0a23eee081a9e085d": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6840bc285801445eafe45e9cfc4a3216": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "2.0.0",
"_view_name": "HTMLView",
"description": "",
"description_allow_html": false,
"layout": "IPY_MODEL_57464fe9afd448b0a23eee081a9e085d",
"placeholder": "",
"style": "IPY_MODEL_5223c9213fef443497e360398c21149f",
"tabbable": null,
"tooltip": null,
"value": " 548M/548M [00:10<00:00, 69.9MB/s]"
}
},
"6e8fb617074d452eb01dbfe715d3827f": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"96dfbfe851a544c09ae13797ba4d4198": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "2.0.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "2.0.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border_bottom": null,
"border_left": null,
"border_right": null,
"border_top": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a24b158989e04d0e90ba67a8b670ce52": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"bc6bcbf06ef44f2585d8d82ac182ea56": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "2.0.0",
"model_name": "HTMLStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "2.0.0",
"_model_name": "HTMLStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "2.0.0",
"_view_name": "StyleView",
"background": null,
"description_width": "",
"font_size": null,
"text_color": null
}
}
},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}