Code for Proximal Policy Optimization Algorithms

Implementation in 100 lines of Python · Proximal Policy Optimization Algorithms
Abstract (original paper)

We propose a new family of policy gradient methods for reinforcement learning, which alternate between sampling data through interaction with the environment, and optimizing a "surrogate" objective function using stochastic gradient ascent. Whereas standard policy gradient methods perform one gradient update per data sample, we propose a novel objective function that enables multiple epochs of minibatch updates. The new methods, which we call proximal policy optimization (PPO), have some of the benefits of trust region policy optimization (TRPO), but they are much simpler to implement, more general, and have better sample complexity (empirically). Our experiments test PPO on a collection of benchmark tasks, including simulated robotic locomotion and Atari game playing, and we show that PPO outperforms other online policy gradient methods, and overall strikes a favorable balance between sample complexity, simplicity, and wall-time.

Source: Schulman et al., "Proximal Policy Optimization Algorithms", arXiv:1707.06347 (2017-08-28).
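The "surrogate" objective the abstract refers to is the clipped objective from the paper: with probability ratio r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t) and advantage estimate Â_t,

    L^CLIP(θ) = Ê_t[ min( r_t(θ) Â_t, clip(r_t(θ), 1 − ε, 1 + ε) Â_t ) ],

which the code below implements as policy_loss, with ε (clip_range) annealed linearly over training.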

Code

Proximal Policy Optimization Algorithms in 100 lines (Python)

import gym
import torch
import random
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from torch.utils.data import DataLoader, TensorDataset


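# Actor-critic network: a shared convolutional torso (Tanh activations) over a stack of 4 grayscale
# 84x84 frames, followed by separate linear heads for the policy logits and the state value.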
class ActorCritic(nn.Module):
    def __init__(self, nb_actions):
        super().__init__()
        self.head = nn.Sequential(nn.Conv2d(4, 16, 8, stride=4), nn.Tanh(),
                                  nn.Conv2d(16, 32, 4, stride=2), nn.Tanh(),
                                  nn.Flatten(), nn.Linear(2592, 256), nn.Tanh(),)
        self.actor = nn.Sequential(nn.Linear(256, nb_actions))
        self.critic = nn.Sequential(nn.Linear(256, 1),)

    def forward(self, x):
        h = self.head(x)
        return self.actor(h), self.critic(h)


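# Wrapper around several parallel Breakout environments: resets press FIRE for a random number of
# steps to launch the ball, a lost life is reported as an episode boundary ("done"), and the
# unclipped per-episode reward is tracked for logging.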
class Environments():
    def __init__(self, nb_actor):
        self.envs = [self.get_env() for _ in range(nb_actor)]
        self.observations = [None for _ in range(nb_actor)]
        self.current_life = [None for _ in range(nb_actor)]
        self.done = [False for _ in range(nb_actor)]
        self.total_rewards = [0 for _ in range(nb_actor)]
        self.nb_actor = nb_actor

        for env_id in range(nb_actor):
            self.reset_env(env_id)

    def len(self):
        return self.nb_actor

    def reset_env(self, env_id):
        self.total_rewards[env_id] = 0
        self.envs[env_id].reset()

        for _ in range(random.randint(1, 30)):  # Noop and fire to reset environment
            self.observations[env_id], reward, _, info = self.envs[env_id].step(1)
            self.total_rewards[env_id] += reward
            self.current_life[env_id] = info['lives']

    def step(self, env_id, action):
        next_obs, reward, dead, info = self.envs[env_id].step(int(action))
        done = info['lives'] < self.current_life[env_id]  # A lost life ends the episode for advantage estimation
        self.done[env_id] = done
        self.total_rewards[env_id] += reward
        self.current_life[env_id] = info['lives']
        self.observations[env_id] = next_obs
        return next_obs, reward, dead, done, info

    def get_env(self):
        env = gym.make("BreakoutNoFrameskip-v4")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env = MaxAndSkipEnv(env, skip=4)
        return env


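# PPO training loop: each iteration collects T steps from every actor with the current policy
# (acting as θ_old), computes GAE advantages, and then runs K epochs of minibatch updates on the
# clipped surrogate objective with value-loss and entropy terms.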
def PPO(envs, actorcritic, T=128, K=3, batch_size=32*8, gamma=0.99, device='cuda', gae_parameter=0.95,
        vf_coeff_c1=1, ent_coef_c2=0.01, nb_iterations=40_000):

    optimizer = torch.optim.Adam(actorcritic.parameters(), lr=2.5e-4)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1., end_factor=0.0, total_iters=nb_iterations)

    max_reward = 0
    total_rewards = [[] for _ in range(envs.len())]
    smoothed_rewards = [[] for _ in range(envs.len())]

    for iteration in tqdm(range(nb_iterations)):
        advantages = torch.zeros((envs.len(), T), dtype=torch.float32, device=device)
        buffer_states = torch.zeros((envs.len(), T, 4, 84, 84), dtype=torch.float32, device=device)
        buffer_actions = torch.zeros((envs.len(), T), dtype=torch.long, device=device)
        buffer_logprobs = torch.zeros((envs.len(), T), dtype=torch.float32, device=device)
        buffer_state_values = torch.zeros((envs.len(), T+1), dtype=torch.float32, device=device)
        buffer_rewards = torch.zeros((envs.len(), T), dtype=torch.float32, device=device)
        buffer_is_terminal = torch.zeros((envs.len(), T), dtype=torch.float16, device=device)

        for env_id in range(envs.len()):
            with torch.no_grad():
                for t in range(T):  # Run policy θ_old in environment for T timesteps
                    obs = torch.from_numpy(envs.observations[env_id] / 255.).unsqueeze(0).float().to(device)
                    logits, value = actorcritic(obs)
                    logits, value = logits.squeeze(0), value.squeeze(0)
                    m = torch.distributions.categorical.Categorical(logits=logits)

                    if envs.done[env_id]:  # Fire to reset env
                        action = torch.tensor([1]).to(device)
                    else:
                        action = m.sample()

                    log_prob = m.log_prob(action)
                    _, reward, dead, done, _ = envs.step(env_id, action)
                    reward = np.sign(reward)  # Reward clipping

                    buffer_states[env_id, t] = obs
                    buffer_actions[env_id, t] = action
                    buffer_logprobs[env_id, t] = log_prob
                    buffer_state_values[env_id, t] = value
                    buffer_rewards[env_id, t] = reward
                    buffer_is_terminal[env_id, t] = done

                    if dead:
                        if envs.total_rewards[env_id] > max_reward:
                            max_reward = envs.total_rewards[env_id]
                            torch.save(actorcritic.cpu(), f"actorcritic_{max_reward}")
                            actorcritic.to(device)

                        total_rewards[env_id].append(envs.total_rewards[env_id])
                        envs.reset_env(env_id)

                buffer_state_values[env_id, T] = actorcritic(
                    torch.from_numpy(envs.observations[env_id] / 255.).unsqueeze(0).float().to(device))[1].squeeze(0)

                # Compute advantage estimates A^1; ... ; A^T
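                # GAE(λ): δ_t = r_t + γ·V(s_{t+1})·(1 − done_t) − V(s_t);  Â_t = δ_t + γλ·(1 − done_t)·Â_{t+1}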
                for t in range(T-1, -1, -1):
                    next_non_terminal = 1.0 - buffer_is_terminal[env_id, t]
                    delta_t = buffer_rewards[env_id, t] + gamma * buffer_state_values[
                        env_id, t+1] * next_non_terminal - buffer_state_values[env_id, t]
                    if t == (T-1):
                        A_t = delta_t
                    else:
                        A_t = delta_t + gamma * gae_parameter * advantages[env_id, t+1] * next_non_terminal
                    advantages[env_id, t] = A_t

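        # Every 400 iterations, plot the mean episode reward per environment and clear the collected returns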
        if (iteration % 400 == 0) and iteration > 0:
            for env_id in range(envs.len()):
                smoothed_rewards[env_id].append(np.mean(total_rewards[env_id]))
                plt.plot(smoothed_rewards[env_id])
            total_rewards = [[] for _ in range(envs.len())]
            plt.title("Average Reward on Breakout")
            plt.xlabel("Training Epochs")
            plt.ylabel("Average Reward per Episode")
            plt.savefig('average_reward_on_breakout.png')
            plt.close()

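        # Update phase: K epochs over shuffled minibatches of the flattened rollout; the stored
        # log-probs and state values play the role of the old policy π_θ_old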
        for epoch in range(K):
            advantages_data_loader = DataLoader(
                TensorDataset(advantages.reshape(advantages.shape[0] * advantages.shape[1]),
                              buffer_states.reshape(-1, buffer_states.shape[2], buffer_states.shape[3],
                                                    buffer_states.shape[4]),
                              buffer_actions.reshape(-1),
                              buffer_logprobs.reshape(-1),
                              buffer_state_values[:, :T].reshape(-1),),
                batch_size=batch_size, shuffle=True,)

            for batch_advantages in advantages_data_loader:
                b_adv, obs, action_that_was_taken, old_log_prob, old_state_values = batch_advantages

                logits, value = actorcritic(obs)
                value = value.squeeze(-1)  # critic output (batch, 1) -> (batch,)
                m = torch.distributions.categorical.Categorical(logits=logits)
                log_prob = m.log_prob(action_that_was_taken)
                ratio = torch.exp(log_prob - old_log_prob)
                returns = b_adv + old_state_values

                # clipped surrogate objective
                policy_loss_1 = b_adv * ratio
                alpha = 1. - iteration / nb_iterations
                clip_range = 0.1 * alpha
                policy_loss_2 = b_adv * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

                # clipped value loss (clip the new value around the old value estimate); see https://arxiv.org/pdf/2005.12729
                value_loss1 = F.mse_loss(returns, value, reduction='none')
                value_clipped = old_state_values + torch.clamp(value - old_state_values, -clip_range, clip_range)
                value_loss2 = F.mse_loss(returns, value_clipped, reduction='none')
                value_loss = torch.max(value_loss1, value_loss2).mean()

                loss = policy_loss - ent_coef_c2 * m.entropy().mean() + vf_coeff_c1 * value_loss

                # Gradient step on θ; θ_old is implicitly the rollout policy whose log-probs and values were stored
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(actorcritic.parameters(), 0.5)
                optimizer.step()
        scheduler.step()


if __name__ == "__main__":
    device = 'cuda'
    nb_actor = 8
    envs = Environments(nb_actor)
    actorcritic = ActorCritic(envs.envs[0].action_space.n).to(device)
    PPO(envs, actorcritic, device=device)
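
A minimal evaluation sketch (not part of the original 100 lines): it reuses the Environments wrapper and loads a checkpoint written by the training loop above. The filename is a placeholder for one of the saved actorcritic_<score> files, and because the full module (not a state_dict) was saved, recent PyTorch versions may additionally require weights_only=False in torch.load.

def evaluate(checkpoint_path, episodes=5, device='cuda'):
    # Greedy rollout of a saved ActorCritic; FIRE is forced after a lost life, as in training
    envs = Environments(1)
    model = torch.load(checkpoint_path, map_location=device)  # the full module was saved above
    model.eval()
    for _ in range(episodes):
        envs.reset_env(0)
        dead = False
        while not dead:
            obs = torch.from_numpy(envs.observations[0] / 255.).unsqueeze(0).float().to(device)
            with torch.no_grad():
                logits, _ = model(obs)
            action = 1 if envs.done[0] else int(logits.argmax(dim=-1))
            _, _, dead, _, _ = envs.step(0, action)
        print("episode reward:", envs.total_rewards[0])

# evaluate("actorcritic_<score>", episodes=3)  # placeholder checkpoint name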
