관심있는 주제/강화학습

torch tensor concat 하는 방법

Lynn123 2023. 1. 10. 22:59
반응형

Background

강화학습을 하다 보면 buffer 를 구성해야 할 일이 많다.

 

강화학습을 진행하면서 나온 trajectory를 buffer에 저장했다가 update 시에 꺼내서 사용해야 하기 때문에 buffer를 구성해서 사용하는데,

이때 늘 episode 길이가 동일해서 buffer내 rewards, actions, obs 등의 size 가 항상 동일하거나 rollout 으로 늘 동일한 step 수 만큼 저장해서 동일한 size 의 항목들만 다룬다면 처음부터 고정된 size 의 buffer를 사용하면 되지만 그렇지 않은 경우엔 비어있는 buf를 만들어서 append (concat 등) 을 해야 하는 경우가 있다.

 

아래는 buffer에서 사이즈를 고정으로 사용할 때 예시이다. 아래와 같이 torch.zeros로 고정된 size의 tensor를 만든 후 insert 하는 구조가 나올 수 있다.

def __init__(self, num_steps, num_processes, obs_shape, gamma):
    self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape)
    self.rewards = torch.zeros(num_steps, num_processes, 1)
    self.returns = torch.zeros(num_steps + 1, num_processes, 1)
    self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
    self.actions = torch.zeros(num_steps, num_processes, 1).long()
    self.masks = torch.ones(num_steps + 1, num_processes, 1)

    self.num_steps = num_steps
    self.gamma = gamma
    self.step = 0
def insert(self, obs, actions, action_log_probs, rewards, masks):
    self.obs[self.step + 1].copy_(obs)
    self.actions[self.step].copy_(actions)
    self.action_log_probs[self.step].copy_(action_log_probs)
    self.rewards[self.step].copy_(rewards)
    self.step = (self.step + 1) % self.num_steps
    self.masks[self.step + 1].copy_(masks)

 

고정 size 를 이용하지 않는 경우

고정된 size를 이용하지 않기 때문에 처음부터 buffer 내 값들을 채워 놓고 시작할 수 없을 땐 아래처럼 비어있는 tensor를 만들고 concat 하는 방식으로 진행하면 된다.

self.obs = torch.zeros(1, self.num_processes, *self.obs_shape)
self.rewards =  torch.tensor([])
self.returns = torch.tensor([])
self.action_log_probs = torch.tensor([])
self.actions = torch.tensor([])
self.masks = torch.ones(1, self.num_processes, 1)
self.step = 0
def insert(self, obs, actions, action_log_probs, rewards, masks):
    self.obs = torch.concat([self.obs, obs[None]])
    self.actions = torch.concat([self.actions, actions[None]]) if self.actions.nelement() else actions[None]
    self.action_log_probs = torch.concat(
        [self.action_log_probs, action_log_probs[None]]
        ) if self.action_log_probs.nelement() else action_log_probs[None]
    self.rewards = torch.concat([self.rewards, rewards[None]]) if self.rewards.nelement() else rewards[None]
    self.masks = torch.concat([self.masks, masks[None]])
    self.step =+ 1

 

 

insert 하는 방법을 하나만 가져와서 설명한다면,

self.rewards = torch.concat([self.rewards, rewards[None]]) if self.rewards.nelement() else rewards[None]

 

먼저 해당 텐서가 비어있는 값인지 (reset 된 초기 텐서인지) 아닌지 확인하여 저장할 reward tensor를 insert 해준다.

이때 비어 있는 텐서에 바로 concat 하면 차원을 맞춰주기 까다로울 수 있으므로 다는 비어있는지 아닌지 판단하고 비어있다면 rewars 자체를 self.rewards 로 덮어 쓰게 해줬다.

 

이때 [None]의 역할은 차원을 하나 늘려준다고 생각하면 된다. 예를 들어, [1] 이라는 reward가 들어온다면, buffer 에 저장하는 형태는 [ [1] ] 이 될 것이고 뒤이어 들어오는 rewards 까지 저장한다면 [ [1], [1], [0], [2], ... [0] ] 이런 형태가 될텐데 이를 맞춰주기 위한 장치라고 이해하면 된다.

 

 

Numpy 를 이용하는 경우

buffer 코드를 다르게 짜긴 했는데, 딱 사용하는 부분만 보면 아래와 같다.

self.obs_buf = np.array([])
self.act_buf = np.array([])  # action
self.adv_buf = np.array([])  # advantage
self.rew_buf = np.array([])  # reward
self.ret_buf = np.array([])  # target value
self.val_buf = np.array([])  # value
self.done_buf = np.array([])  # done
self.logp_buf = np.array([])  # log probability
self.obs_buf = (
     np.vstack([self.obs_buf, buffer.obs_buf]) if self.obs_buf.size else buffer.obs_buf
     )
self.act_buf = np.hstack([self.act_buf, buffer.act_buf]) if self.act_buf.size else buffer.act_buf
self.rew_buf = np.hstack([self.rew_buf, buffer.rew_buf]) if self.rew_buf.size else buffer.rew_buf
self.val_buf = np.hstack([self.val_buf, buffer.val_buf]) if self.val_buf.size else buffer.val_buf
self.logp_buf = np.hstack([self.logp_buf, buffer.logp_buf]) if self.logp_buf.size else buffer.logp_buf
self.done_buf = np.hstack([self.done_buf, buffer.done_buf]) if self.done_buf.size else buffer.done_buf
self.adv_buf = np.hstack([self.adv_buf, buffer.adv_buf]) if self.adv_buf.size else buffer.adv_buf
self.ret_buf = np.hstack([self.ret_buf, buffer.ret_buf]) if self.ret_buf.size else buffer.ret_buf

 

반응형