返回介绍

DQN 算法更新 (Tensorflow)

发布于 2025-05-02 13:36:21 字数 3511 浏览 0 评论 0 收藏

作者: Morvan 编辑: Morvan

Deep Q Network 的简称叫 DQN, 是将 Q learning 的优势 和 Neural networks 结合了. 如果我们使用 tabular Q learning, 对于每一个 state, action 我们都需要存放在一张 q_table 的表中. 如果像显示生活中,情况可就比那个迷宫的状况复杂多了,我们有千千万万个 state, 如果将这千万个 state 的值都放在表中,受限于我们计算机硬件,这样从表中获取数据,更新数据是没有效率的. 这就是 DQN 产生的原因了. 我们可以使用神经网络来 估算 这个 state 的值,这样就不需要一张表了。

这次的教程我们还是基于熟悉的 迷宫 环境,重点在实现 DQN 算法,之后我们再拿着做好的 DQN 算法去跑其他更有意思的环境。

本节内容包括:

算法

整个算法乍看起来很复杂,不过我们拆分一下,就变简单了. 也就是个 Q learning 主框架上加了些装饰。

这些装饰包括:

  • 记忆库 (用于重复学习)
  • 神经网络计算 Q 值
  • 暂时冻结 q_target 参数 (切断相关性)

算法的代码形式

接下来我们对应上面的算法,来实现主循环. 首先 import 所需模块。

from maze_env import Maze
from RL_brain import DeepQNetwork

下面的代码,就是 DQN 于环境交互最重要的部分。

def run_maze():
    step = 0    # 用来控制什么时候学习
    for episode in range(300):
        # 初始化环境
        observation = env.reset()

        while True:
            # 刷新环境
            env.render()

            # DQN 根据观测值选择行为
            action = RL.choose_action(observation)

            # 环境根据行为给出下一个 state, reward, 是否终止
            observation_, reward, done = env.step(action)

            # DQN 存储记忆
            RL.store_transition(observation, action, reward, observation_)

            # 控制学习起始时间和频率 (先累积一些记忆再开始学习)
            if (step > 200) and (step % 5 == 0):
                RL.learn()

            # 将下一个 state_ 变为 下次循环的 state
            observation = observation_

            # 如果终止,就跳出循环
            if done:
                break
            step += 1   # 总步数

    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = DeepQNetwork(env.n_actions, env.n_features,
                      learning_rate=0.01,
                      reward_decay=0.9,
                      e_greedy=0.9,
                      replace_target_iter=200,  # 每 200 步替换一次 target_net 的参数
                      memory_size=2000, # 记忆上限
                      # output_graph=True   # 是否输出 tensorboard 文件
                      )
    env.after(100, run_maze)
    env.mainloop()
    RL.plot_cost()  # 观看神经网络的误差曲线

下一节我们会来讲解 DeepQNetwork 这种算法具体要怎么编。

如果想一次性看到全部代码,请去我的 Github

如果你觉得这篇文章或视频对你的学习很有帮助,请你也分享它,让它能再次帮助到更多的需要学习的人。

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。