强化学习实战——Q-Learing和SASAR悬崖探宝

in #cn-stem5 years ago

Stormtrooper minifigure walking on the sand

image source from unsplash.com by Daniel Cheung

之前我们介绍了Q-learning和SASAR算法的理论,这篇文章就理论结合实际用Q-learning 和SASAR算法指导智能体,完成悬崖探宝任务。

同样的,为了方便与读者交流,所有的代码都放在了这里:

https://github.com/zht007/tensorflow-practice

1. 环境简介

智能体在下图4 *12的格子世界中活动,"x"代表起点和智能体当前的位置,"T"代表终点,"C"代表悬崖,"o"代表其他位置。

o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
x  C  C  C  C  C  C  C  C  C  C  T

状态(States):当前位置

奖励(Rewards):终点为1,悬崖为-100,其他地方为1

行动(Action):上下左右,四个方向

2. SARSA算法

SARSA全称State–Action–Reward–State–Action,是on-policy的算法,即只有一个策略指挥行动并同时被更新。顾名思义,该算法需要5个数据,当前的 state, reward,action下一步state和action。两步action和state均由epsilon greedy策略指导。

2.1 定义epsilon greedy的策略

为了保证On-Policy的算法能访问到所有的状态,SARSA所有的行动策略必须是epsilon greedy的,这里定义的epsilon greedy策略与前文中是一样的。

def make_epsilon_greedy_policy(Q, epsilon, nA):
    def policy_fn(observation):
        A = np.ones(nA, dtype=float) * epsilon / nA
        best_action = np.argmax(Q[observation])
        A[best_action] += (1.0 - epsilon)
        return A
    return policy_fn

该部分代码参考github with MIT license

2.2 定义SARSA算法

首先,根据当前策略迈出第一步,获得当前的S和A

policy = make_epsilon_greedy_policy(Q, epsilon, env.action_space.n)
for i_episode in range(num_episodes):
        # First action
        state = env.reset()
        action_probs = policy(state)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)

该部分代码参考github with MIT license

然后,进入循环,直到游戏结束(if done: break)。该循环是为了获得当前的R下一步的S‘,和A‘。并带入公式更新Q(S ,A),由于策略是通过Q(S, A)生成的,所以更新Q(S,A)的同时,策略也更新了。
$$
Q(S, A) \leftarrow Q(S, A)+\alpha\left[R+\gamma Q\left(S^{\prime}, A^{\prime}\right)-Q(S, A)\right]
$$

                while True:
            next_state, reward, done, _ = env.step(action)
            
            next_action_probs = policy(next_state)         
            next_action = np.random.choice(np.arange(len(action_probs)),         p=next_action_probs)
            
            Q[state][action] += alpha * (reward + discount_factor * Q[next_state][next_action] - Q[state][action])
            
            if done:
                break
            state = next_state
            action = next_action 
   return Q

该部分代码参考github with MIT license

将当前S和A替换成下一步S‘和A‘,直到游戏结束,最终得到优化后的Q表。

3. Q-Learning算法

Q-learning 与 SASAR有非常多的相似之处,但是本质上,Q-learning是Off-Policy的算法。也就是说Q-learning有两套Policy,Behavior Policy 和 Target Policy, 一个用于探索另一个用于优化。

与SARSA一样Q-Learning 也需要定义相同的epsilon greedy的策略,这里略过,我们看看算法本身的代码。

policy = make_epsilon_greedy_policy(Q, epsilon, env.action_space.n)
    
    for i_episode in range(num_episodes):
        state = env.reset()
        
        while True:
            action_probs = policy(state)
            action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
            next_state, reward, done, _ = env.step(action)
            best_action = np.argmax(Q[state])
            
            Q[state][action] += alpha * (reward + discount_factor * Q[next_state][best_action] - Q[state][action])
                  
            if done:
                break
            state = next_state
    
    return Q

该部分代码参考github with MIT license

该算法与SARSA的区别是,Q-Learning在行动的时候采用epsilon greedy的策略(Behavior Policy),但是在更新 Target Policy 的Q(S,A)时候,采用的是greedy的策略,即下一步的最大回报(best_action = np.argmax(Q[state]))

4. 总结

上文介绍的SASAR和Q-learning都属于单步Temporal Difference (时间差分TD(0))算法,其通用的更新公式为

Q[s, a]+=learning_rate *(td_target-Q[s, a])

其中 td_target - Q[s,a] 部分又叫做 TD Error.

SARSA算法:

td_target =R[t+1] + discout_factor*Q[s',a']

Q-learning:

td_target =R[t+1] + discout_factor*max(Q[s'])

关于两个算法的对比,我们可以看看两者最终的行动轨迹,首先定义渲染函数

def render_evn(Q):
    state = env.reset()
    while True:
        next_state, reward, done, _  = env.step(np.argmax(Q[state]))
        env.render()
        if done:
            break
        state = next_state

SARSA算法

Q1, stats1 = sarsa(env, 500)
render_evn(Q1)

----output---
x  x  x  x  x  x  x  x  x  x  x  x
x  o  o  o  o  o  o  o  o  o  o  x
x  o  o  o  o  o  o  o  o  o  o  x
x  C  C  C  C  C  C  C  C  C  C  T

Q-learing

Q1, stats1 = sarsa(env, 500)
render_evn(Q1)

----output---
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
x  x  x  x  x  x  x  x  x  x  x  x
x  C  C  C  C  C  C  C  C  C  C  T

可以看出Q-learning得到的policy是沿着悬崖的最短(最佳)路径,获得的奖励最多,然而这样做却十分危险,因为在行动中由于采用的epsilon greedy的策略,有一定的几率掉进悬崖。SARSA算法由于是On-policy的在更新的时候意识到了掉进悬崖的危险,所以它选择了一条更加安全的路径,即多走两步,绕开悬崖。


参考资料

[1] Reinforcement Learning: An Introduction (2nd Edition)

[2] David Silver's Reinforcement Learning Course (UCL, 2015)

[3] Github repo: Reinforcement Learning


相关文章

强化学习——MC(蒙特卡洛)玩21点扑克游戏

强化学习实战——动态规划(DP)求最优MDP

强化学习——强化学习的算法分类

强化学习——重拾强化学习的核心概念

AI学习笔记——Sarsa算法

AI学习笔记——Q Learning

AI学习笔记——动态规划(Dynamic Programming)解决MDP(1)

AI学习笔记——动态规划(Dynamic Programming)解决MDP(2)

AI学习笔记——MDP(Markov Decision Processes马可夫决策过程)简介

AI学习笔记——求解最优MDP


同步到我的简书
https://www.jianshu.com/u/bd506afc6fc1

Sort:  

吃了吗?这是哪里?你是谁?我为什么会来这边?你不要给我点赞不要点赞,哈哈哈哈哈哈。如果不想再收到我的留言,请回复“取消”。



This post has been voted on by the SteemSTEM curation team and voting trail. It is elligible for support from @curie.

If you appreciate the work we are doing, then consider supporting our witness stem.witness. Additional witness support to the curie witness would be appreciated as well.

For additional information please join us on the SteemSTEM discord and to get to know the rest of the community!

Please consider setting @steemstem as a beneficiary to your post to get a stronger support.

Please consider using the steemstem.io app to get a stronger support.

Coin Marketplace

STEEM 0.19
TRX 0.15
JST 0.029
BTC 62676.37
ETH 2581.43
USDT 1.00
SBD 2.72