博客
关于我
自编码器模型详解与实现(采用tensorflow2.x实现)
阅读量:783 次
发布时间:2019-03-25

本文共 7897 字,大约阅读时间需要 26 分钟。

自编码器模型详解与实现(采用TensorFlow 2.x 实现)

1. 自编码器与潜变量学习概述

自编码器是一种在无监督学习中广泛应用的深度学习模型,由Geoffrey Hinton等人于1980年代首次提出。它的核心目标是通过压缩高维输入空间到低维潜变量Representation(以下简称“潜”),并在解码阶段将这些潜还原为原始高维输入。这种能力使其在图像处理、材质分离等领域具有重要应用价值。

在图像处理领域,自编码器可以类比于数据压缩与解压过程。例如,就如JPEG将高分辨率图像压缩为小文件格式一样,自编码器则可以将原始图像压缩为低维潜变量,再通过解码器还原回高分辨率图像。这使得自编码器成为一种高效的图像压缩与恢复工具。

2. 自编码器架构详解

2.1 编码器设计

编码器负责将高维输入通过一系列的神经网络层压缩为低维潜变量。我们以MNIST数据集为例,构建一个适用于28x28x1输入尺寸的编码器。潜变量的维度设置为低于输入维度的超参数,这里采用10维。

以下是编码器的实现代码框架:

def Encoder(z_dim):    inputs = layers.Input(shape=[28, 28, 1])    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Flatten()(x)    out = Dense(z_dim, activation='relu')(x)    return Model(inputs=inputs, outputs=out, name='encoder')

编码器主要包含卷积层和全连接层。卷积层用于提取高层次的特征,同时通过调整卷积核的步长(如2)实现特征图的下采样,逐步减少输入的高维信息。全连接层则负责将多个特征图融合到低维潜变量空间中。

2.2 解码器设计

解码器的任务是将低维潜变量还原为高维图像。其结构与编码器相似,但需在解码过程中通过卷积层和上采样操作逐步还原特征图。

以下是解码器的实现代码框架:

def Decoder(z_dim):    inputs = layers.Input(shape=[z_dim])    x = Dense(7*7*64, activation='relu')(x)    x = Reshape((7,7,64))(x)    x = Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    x = Conv2D(filters=32, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    out = Conv2D(filters=1, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')(x)    return Model(inputs=inputs, outputs=out, name='decoder')

解码器通过卷积层在低维空间中生成特征图,并结合上采样操作逐渐将特征图还原为原始图像尺寸。上采样方法包括卷积核转置(例如UpSampling2D)或仿射变换,但后者是不训练的参数,通常不适合深度学习模型。

3. 自编码器模型构建

将编码器和解码器组合,构建完整的自编码器模型:

z_dim = 10encoder = Encoder(z_dim)decoder = Decoder(z_dim)model_input = encoder.inputmodel_output = decoder(encoder.output)autoencoder = Model(model_input, model_output)
3.1 模型训练

为了训练模型,我们采用MSE(均方误差)损失函数,旨在最小化编码器输出与解码器预测值之间的差异。同时,使用一些训练回调(如ModelCheckpointEarlyStopping)来优化训练过程。

autoencoder.compile(loss='mse', optimizer='rmsprop', lr=3e-4)

训练过程中,我们需要分成训练集和验证集,定期保存最佳模型参数以防止过拟合。

4. 从潜变量生成图像

自编码器的潜变量具有潜在的生成能力。比如,如果我们定义另一个解码器仅使用潜变量生成图像,可以利用这个能力进行高效的图像生成。

z_dim = 2  # 定义更低维的潜变量autoencoder_2 = Autoencoder(z_dim=2)

通过对潜变量进行采样,可以生成大量不同样本。如上图所示,我们采用2维潜变量空间,生成500个样本,散布在二维平面上。通过观察标签分布图,可以发现某些类别的潜变量代表性较强,而另一些类别则相对模糊。

更进一步地,我们可以通过滑动窗口或交互式工具(如下图所示),进行潜变量的可视化和探索。

from ipywidgets import interact, interact_manual@interactdef explore_latent_variable(z1=(-5,5,0.1), z2=(-5,5,0.1)):    z_samples = np.array([[z1, z2] for z2 in np.arange(-5,5,0.1)] for z1 in np.arange(-5,5,0.1))    images = autoencoder_2.decoder.predict(z_samples)    plt.figure(figsize=(2,2))    plt.imshow(images[0,:,:,0], cmap='gray')

完整代码示例

import tensorflow as tffrom tensorflow.keras import layers, Modelfrom tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Reshape, Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLUfrom tensorflow.keras.activations import relufrom tensorflow.keras.models import Sequential, load_modelfrom tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingimport tensorflow_datasets as tfdsimport numpy as npimport matplotlib.pyplot as pltimport warningswarnings.filterwarnings('ignore')print(tf.__version__)# 加载MNIST数据集(ds_train, ds_test), ds_info = tfds.load(    'mnist',    split=['train', 'test'],    shuffle_files=True,    as_supervised=True,    with_info=True)# 预处理数据def preprocess(image, label):    image = tf.cast(image, tf.float32)    image = image / 255.    return image, imageds_train = ds_train.cache().shuffle(ds_info.splits['train'].num_examples).batch(batch_size, drop_remainder=True)ds_test = ds_test.cache().batch(batch_size, drop_remainder=True).prefetch(batch_size)def Encoder(z_dim):    inputs = layers.Input(shape=[28, 28, 1])    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Flatten()(x)    out = Dense(z_dim, activation='relu')(x)    return Model(inputs=inputs, outputs=out, name='encoder')def Decoder(z_dim):    inputs = layers.Input(shape=[z_dim])    x = Dense(7*7*64, activation='relu')(x)    x = Reshape((7,7,64))(x)    x = Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    x = Conv2D(filters=32, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    out = Conv2D(filters=1, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')(x)    return Model(inputs=inputs, outputs=out, name='decoder')class Autoencoder:    def __init__(self, z_dim):        self.encoder = Encoder(z_dim)        self.decoder = Decoder(z_dim)        self.model_input = self.encoder.input        self.model_output = self.decoder(self.model_input)        self.model = Model(self.model_input, self.model_output)autoencoder = Autoencoder(z_dim=10)# 训练设置model_path = 'autoencoder.h5'checkpoint = ModelCheckpoint(model_path,                         monitor="val_loss",                        verbose=1,                        save_best_only=True,                         mode="auto",                        save_weights_only=False)early = EarlyStopping(monitor="val_loss",                      mode="auto",                      patience=5)callbacks_list = [checkpoint, early]autoencoder.model.compile(loss='mse',                        optimizer='rmsprop',                        lr=3e-4)autoencoder.model.fit(ds_train,                       validation_data=ds_test,                       epochs=100,                       callbacks=callbacks_list)# 加载预训练模型autoencoder.model = load_model(model_path)images, labels = next(iter(ds_test))outputs = autoencoder.model.predict(images)# 显示恢复后的图像plt.figure(figsize=(10, 2))for i in range(0, 64, 2):    plt.figure(figsize=(5, 2))    for j in range(2):        ax = plt.subplot(2, 5, j + i*2)  # 调整图像位置        ax.imshow(images[i, j], cmap='gray')        ax.axis('off')    plt.show()autoencoder_2 = Autoencoder(z_dim=2)model_path_2 = 'autoencoder_2.h5' checkpoint_2 = ModelCheckpoint(model_path_2,                             monitor="val_loss",                             verbose=1,                             save_best_only=True,                              mode="auto",                             save_weights_only=False) early_2 = EarlyStopping(monitor="val_loss",                        mode="auto",                        patience=5) callbacks_list_2 = [checkpoint_2, early_2]autoencoder_2.model.compile(loss="mse",                        optimizer='rmsprop',                        lr=1e-3)autoencoder_2.model.fit(ds_train,                       validation_data=ds_test,                       epochs=50,                       callbacks=callbacks_list_2)images_2, labels_2 = next(iter(ds_test))# 观察潜变量分布encoder_outputs = autoencoder_2.encoder.predict(images_2)plt.figure(figsize=(8, 8))plt.scatter(encoder_outputs[:, 0], encoder_outputs[:, 1], c=labels_2, cmap='RdYlBu', s=3)plt.colorbar()plt.show()z_samples = np.array([[z1, z2] for z1 in np.arange(-5,5,1.) for z2 in np.arange(-5,5,1.)])decoded_images = autoencoder_2.decoder.predict(z_samples)plt.figure(figsize=(10, 10))for i in range(100):    plt.figure(figsize=(5,5))    for j in range(10):        ax = plt.subplot(10, 10, i*10 + j + 1)        ax.imshow(decoded_images[i, j], cmap='gray')        ax.axis('off')plt.show()

结论

通过以上实现,我们成功构建并训练了一个自编码器模型,能够将MNIST数据集中的影像压缩为低维潜变量并还原回高分辨率图像。这种模型不仅能够实现图像压缩,还可以用于图像去噪、风格迁移等多种任务。通过探索潜变量空间,我们还可以发现输入数据中的潜在特征分布,以进一步提升模型性能和应用效果。

转载地址:http://mynuk.baihongyu.com/

你可能感兴趣的文章
Mysql 整形列的字节与存储范围
查看>>
mysql 断电数据损坏,无法启动
查看>>
MySQL 日期时间类型的选择
查看>>
Mysql 时间操作(当天,昨天,7天,30天,半年,全年,季度)
查看>>
MySQL 是如何加锁的?
查看>>
MySQL 是怎样运行的 - InnoDB数据页结构
查看>>
mysql 更新子表_mysql 在update中实现子查询的方式
查看>>
MySQL 有什么优点?
查看>>
mysql 权限整理记录
查看>>
mysql 权限登录问题:ERROR 1045 (28000): Access denied for user ‘root‘@‘localhost‘ (using password: YES)
查看>>
MYSQL 查看最大连接数和修改最大连接数
查看>>
MySQL 查看有哪些表
查看>>
mysql 查看锁_阿里/美团/字节面试官必问的Mysql锁机制,你真的明白吗
查看>>
MySql 查询以逗号分隔的字符串的方法(正则)
查看>>
MySQL 查询优化:提速查询效率的13大秘籍(避免使用SELECT 、分页查询的优化、合理使用连接、子查询的优化)(上)
查看>>
mysql 查询数据库所有表的字段信息
查看>>
【Java基础】什么是面向对象?
查看>>
mysql 查询,正数降序排序,负数升序排序
查看>>
MySQL 树形结构 根据指定节点 获取其下属的所有子节点(包含路径上的枝干节点和叶子节点)...
查看>>
mysql 死锁 Deadlock found when trying to get lock; try restarting transaction
查看>>