SU Library

[강화학습] On-policy(SARSA)와 Off-policy(Q-learning) 본문

인공지능/강화학습

[강화학습] On-policy(SARSA)와 Off-policy(Q-learning)

S U 2024. 6. 2. 15:32

강화학습을 공부하다보면 환경에 최적화된 에이전드a를 구축하는 것이 가장 큰 목표입니다. 여기서 에이전트의 행동을 결정하는데 기여하는 것이 정책이라는 π 입니다. 또한, 매 에피소드마다 행동에 따른 결과 값을 리턴 값r이라 부릅니다. 이러한 최적 정책을 갖게하는 데 있어서 크게 On-policy와 Off-policy의 방법으로 나눠서 에이전트를 학습할 수 있습니다. 이 두방식은 학습과정에서 생성되는 데이터를 활용하여 최적의 정책을 구축하려는 궁극적인 목표는 동일하나 다음과 같은 두가지 차이점이 있습니다.

 

On/off policy

 

  • On-policy 방법에서는 학습 과정에서 생성된 데이터를 이용하여 현재의 정책을 평가하고 개선합니다. 즉, 현재 정책에 따라 행동을 선택하고, 그 행동의 결과로 받은 보상과 다음 상태를 사용하여 같은 정책을 업데이트합니다. 이 방법에서는 학습 정책과 평가 정책이 동일합니다.
  • Off-policy 방법에서는 학습 과정에서 생성된 데이터를 이용하여 다른 정책을 평가하고 개선합니다. 즉, 하나의 정책(행동 정책)으로부터 데이터를 수집하지만, 수집된 데이터를 사용하여 다른 정책(타겟 정책)을 업데이트합니다. 이 방법에서는 학습 정책과 평가 정책이 다를 수 있습니다.

둘은 유연성과 탐험 및 활용에서 다음의 두가지 차이점을 갖습니다.

  • 유연성: Off-policy는 다른 정책으로부터 얻은 데이터를 사용하여 정책을 평가하고 개선할 수 있기 때문에 더 유연합니다. 이는 예를 들어, 과거에 수집된 데이터나 다른 에이전트의 데이터를 재사용할 수 있다는 것을 의미합니다.
  • 탐험과 활용: On-policy 방법은 현재 정책에 따라 탐험과 활용의 균형을 맞춰야 하며, Off-policy 방법은 탐험 정책과 별개로 최적의 정책을 학습할 수 있습니다.

이러한 차이점 때문에 효율성과 최적화된 솔루션을 찾는 쪽에서 Off-policy가 더 유리하고, 최신 논문들도 Off-policy 기반의 모델들을 소개하고 있습니다. On-policy의 대표적인 알고리즘은 SARSA가 있습니다. 반면 Off-policy의 대표적인 알고리즘은 Q-learning이 있습니다. 이둘은 벨만 기대 방정식과 벨만 최적 방정식으로 매우 유사한 방식으로 동작합니다. 

그러면 수식적으로 어떤차이가 있는지 알아보겠습니다.

SARSA:Q(S,A)Q(S,A)+α(R+γQ(S,A)Q(S,A))

수식에서 알 수 있듯이 실제로 선택한 다음 행동을 기반으로 Q 값을 업데이트하는 특징이 있습니다. 따라서 실제로 수행한 행동을 기반으로 학습하는 특징이 있습니다. 

Qlearning:Q(S,A)Q(S,A)+α(R+γmaxAQ(S,A)Q(S,A))

SARSA와 큰차이는 없지만 max항에서 볼수 있듯이 최적의 행동 가치 함수를 직접적으로 추정하는 특징이 있습니다. Q-learning에서는 다음 상태에서 가능한 모든 행동 중에서 최대 Q 값을 선택하여 Q 값을 업데이트합니다. 따라서 실제로 선택된 행동과 상관없이 최적의 행동을 기반으로 학습하는 특징이 있습니다.  그렇다면 다음의 실습 코드를 통해서 어떤 차이점이 있는지 확인해보겠습니다.

 

0,0에서 5,5로 이동하는 그리디 월드를 정의하고 이를 탐험하는 에이전트를 정의합니다.

 

NxN 그리드 월드를 탐험하는 환경정의 코드

import random
import numpy as np

class NbyNGridWorld():
    def __init__(self,N):
        self.x=0
        self.y=0
        self.N=N
    def step(self,a):
        if a==0:
            self.move_right()
        elif a==1:
            self.move_left()
        elif a==2:
            self.move_up()
        elif a==3:
            self.move_down()
        reward =-1
        done = self.is_done()
        return (self.x,self.y),reward,done
    def move_right(self):
        self.y+=1
        if self.y>self.N-1:
            self.y = self.N-1
    def move_left(self):
        self.y-=1
        if self.y<0:
            self.y = 0
    def move_down(self):
        self.x+=1
        if self.x>self.N-1:
            self.x = self.N-1
    def move_up(self):
        self.x-=1
        if self.x<0:
            self.x = 0
    def is_done(self):
        if self.x==self.N-1 and self.y== self.N-1:
            return True
        return False
    def get_state(self):
        return (self.x,self.y)
    def reset(self):
        self.x=0
        self.y=0
        return (self.x,self.y)

 

 

 

 

SARSA 에이전트 정의 코드

class QAgent():
    def __init__(self,N):
        self.n = N
        self.q_table = np.zeros((N,N,4))
        self.eps=0.9
        self.alpha = 0.01
    def select_action(self,s):
        x,y, = s
        coin = random.random()
        if coin < self.eps:
            action = random.randint(0,3)
        else:
            action_val = self.q_table[x,y,:]
            action = np.argmax(action_val)
        return action
    def update_table(self,history):
        cum_reward =0 
        for transition in history[::-1]:
            s,a,r,s_prime = transition
            x,y, = s
            nx,ny = s_prime
            a_prime = self.select_action(s_prime)
            self.q_table[x,y,a] = self.q_table[x,y,a] +self.alpha *(r+self.q_table[nx,ny,a_prime] - self.q_table[x,y,a])
            
    def anneal_eps(self):
        self.eps -=0.03
        self.eps = max(self.eps,0.1)
    def show_table(self):
        q_lst=self.q_table.tolist()
        data =np.zeros((self.n,self.n))
        for ridx in range(self.n):
            row =q_lst[ridx]
            for cidx in range(self.n):
                col = row[cidx]
                action = np.argmax(col)
                data[ridx,cidx] = action
        print(data)


환경과 상호작용하는 코드

def main(N):
    env = NbyNGridWorld(N)
    agent = QAgent(N)

    for n_epi in range(1000):
        done = False
        history =  []
        s =env.reset()
        while not done:
            a = agent.select_action(s)
            s_prime,r,done = env.step(a)
            history.append((s,a,r,s_prime))
            s = s_prime
        agent.update_table(history)
        agent.anneal_eps()
    agent.show_table()

main(5)

5x5 그리드 월드를 여행하는 SARSA에이전트의 결과

Off-policy의 경우 Q-learning이 있습니다. 

 

Q-learning 에이전트 구현 코드

class QAgent():
    def __init__(self,N):
        self.n = N
        self.q_table = np.zeros((N,N,4))
        self.eps=0.9
        self.alpha = 0.01
    def select_action(self,s):
        x,y, = s
        coin = random.random()
        if coin < self.eps:
            action = random.randint(0,3)
        else:
            action_val = self.q_table[x,y,:]
            action = np.argmax(action_val)
        return action
    def update_table(self,history):
        cum_reward =0 
        for transition in history[::-1]:
            s,a,r,s_prime = transition
            x,y, = s
            nx,ny = s_prime
            a_prime = self.select_action(s_prime)
            self.q_table[x,y,a] = self.q_table[x,y,a] +self.alpha *(r+np.max(self.q_table[nx,ny,:]) - self.q_table[x,y,a])
            
    def anneal_eps(self):
        self.eps -=0.03
        self.eps = max(self.eps,0.1)
    def show_table(self):
        q_lst=self.q_table.tolist()
        data =np.zeros((self.n,self.n))
        for ridx in range(self.n):
            row =q_lst[ridx]
            for cidx in range(self.n):
                col = row[cidx]
                action = np.argmax(col)
                data[ridx,cidx] = action
        print(data)
def main(N):
    env = NbyNGridWorld(N)
    agent = QAgent(N)

    for n_epi in range(1000):
        done = False
        history =  []
        s =env.reset()
        while not done:
            a = agent.select_action(s)
            s_prime,r,done = env.step(a)
            history.append((s,a,r,s_prime))
            s = s_prime
        agent.update_table(history)
        agent.anneal_eps()
    agent.show_table()
    
main(5)

Q-learning 알고리즘의 실행결과

둘의 이동횟수는 동일함을 알 수 있습니다. 

'인공지능 > 강화학습' 카테고리의 다른 글

[강화학습] 몬테카를로 학습  (0) 2024.05.23
Comments