Can Reinforcement Learning be Used for Classification?


After digging into reinforcement learning, who doesn’t wonder whether it would be useful for tasks typically reserved for supervised learning?

It’s a natural question after all.

So, I set out to find out in the best way I knew how: writing a bunch of code.

Download the jupyter notebook.

Loading Experimental Data

In an attempt to answer the question without sucking up my only gpu for weeks of training, I went for the MNIST dataset as an experimental baseline.

MNIST is a simple enough problem to be solved in only seconds, but also enough of a challenge that it should answer the question of whether or not reinforcement learning can be used to train a classifier.

If you’re not familiar with it, MNIST is a set of images of handwritten digits (0-9) in black and white. The classification problem is to determine which digit each image represents.

The images look like this:


So let’s code this bad boy.


Our only dependencies are tensorflow and OpenAI Baselines. Let’s get those package installations out of the way:

pip install tensorflow-gpu==1.15.5
pip install git+[email protected]

Then, all the imports at once:

import time
import gym
import random
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

from baselines.ppo2 import ppo2
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

from baselines import deepq
from baselines import bench
from baselines import logger
import tensorflow as tf

from baselines.common.tf_util import make_session

Now we’re ready to dive into the code.

Getting the Data

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train_one_hot = keras.utils.to_categorical(y_train, num_classes)
y_test_one_hot = keras.utils.to_categorical(y_test, num_classes)

With that code we use Keras’ built-in utilities to download the MNIST dataset, load it into memory, scale the pixels, and one-hot encode the labels. It is copied straight from the Keras example page.
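To make the preprocessing concrete, here’s a tiny pure-Python sketch of what the scaling and one-hot steps do to a few values. The helper names `scale_pixels` and `one_hot` are made up for illustration; the real code above uses NumPy and `keras.utils.to_categorical`.

```python
def scale_pixels(pixels):
    """Map 8-bit pixel intensities (0-255) into the [0, 1] range."""
    return [p / 255 for p in pixels]


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


print(scale_pixels([0, 255]))  # [0.0, 1.0]
print(one_hot(3))              # 1.0 in position 3, zeros elsewhere
```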

If you’ve watched any of my videos you probably know that I have a preference for PyTorch, but given that I decided to use baselines for the RL algorithms, keras made for an easier alternative for the supervised learning baseline.

Keras Baseline

Even if RL turns out to work, we won’t know if it’s any good unless we have something to compare it to.

Let’s fix that by training a classifier on the MNIST dataset with traditional supervised learning using Keras.

def keras_train(batch_size=32, epochs=2):
    model = keras.Sequential(
        [
            # Flatten each 28x28x1 image into a 784-vector for the dense layers
            layers.Flatten(input_shape=input_shape),
            layers.Dense(64, activation='relu'),
            layers.Dense(64, activation='relu'),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )

    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    # Time only the call to fit()
    start_time = time.time()
    model.fit(x_train, y_train_one_hot, batch_size=batch_size, epochs=epochs, validation_split=0.1)
    end_time = time.time()

    score = model.evaluate(x_test, y_test_one_hot, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])
    print("Training Time:", end_time - start_time)

keras_train()


Notice that it uses an MLP with two 64-unit hidden layers. This is a pretty classic network size for RL, and I want all of the attempts to have roughly the same number of parameters.

Also, it’s got a batch size of 32, since that’s the typical batch size used in DQN.
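Since the parameter count is the yardstick for the comparisons that follow, here’s a quick sketch of the arithmetic. It assumes each 28x28 image is flattened into 784 inputs before the first dense layer; the helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def mlp_params(layer_sizes):
    """Total trainable parameters of an MLP, given its layer widths in order."""
    return sum(dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))


# 784 inputs -> 64 -> 64 -> 10 outputs
print(mlp_params([28 * 28, 64, 64, 10]))  # 55050
```

Almost all of those parameters (50,240 of them) sit in the first layer, which is why small changes near the output barely move the total.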

When I run it, I get an output like this:

Test loss: 0.11103802259191871
Test accuracy: 0.9658
Training Time: 14.185736656188965

That’s about 96.6% accuracy in only 14 seconds of training. Pretty impressive!

RL Training Environment (gym.Env)

Since OpenAI released their gym library, it has become the de facto standard for RL algorithm training environments.

Let’s build one that adapts the RL environment reward paradigm to the classification problem.

The idea is simple: every image class is a unique action that the agent can take. If it takes the action that corresponds to the correct class, we give +1 reward. Otherwise, we give 0 reward.
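A nice side effect of this reward scheme is that the mean reward over one-image episodes is just classification accuracy. Here’s a minimal sketch of that, with hypothetical helper names:

```python
def reward(action, label):
    """+1 when the chosen action matches the digit label, otherwise 0."""
    return int(action == label)


def mean_reward(actions, labels):
    """Average reward over one-step episodes, which equals accuracy."""
    rewards = [reward(a, y) for a, y in zip(actions, labels)]
    return sum(rewards) / len(rewards)


# Three of four guesses are right, so the mean reward is 0.75
print(mean_reward([3, 1, 4, 1], [3, 1, 4, 7]))  # 0.75
```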

Also, temporal difference learning assumes that actions at one step affect rewards in future steps. That is not true in our environment, since the next image doesn’t depend on the action we just took. So, episodes that last longer than one timestep don’t make any sense in this context.
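To see what one-step episodes do to the learning target, here’s a sketch of the standard one-step Q-learning target. It is a generic textbook form, not code taken from Baselines, and the function name and the `gamma` default are assumptions.

```python
def td_target(reward, done, next_q_values, gamma=0.99):
    """One-step Q-learning target: r + gamma * max_a Q(s', a), or just r at episode end."""
    if done:
        # Terminal step: there is no future to bootstrap from
        return float(reward)
    return reward + gamma * max(next_q_values)


# With one image per episode every step is terminal, so the target is the reward itself
print(td_target(1, True, [0.2, 0.9]))  # 1.0
```

With `done=True` on every step, the Q-values for the next image never enter the target, so the agent is effectively regressing each action’s value onto its immediate reward.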

class MnistEnv(gym.Env):
    def __init__(self, images_per_episode=1, dataset=(x_train, y_train), random=True):

        self.action_space = gym.spaces.Discrete(10)
        self.observation_space = gym.spaces.Box(low=0, high=1,
                                                shape=(28, 28, 1),
                                                dtype=np.float32)

        self.images_per_episode = images_per_episode
        self.step_count = 0

        self.x, self.y = dataset
        self.random = random
        self.dataset_idx = 0

    def step(self, action):
        done = False
        reward = int(action == self.expected_action)

        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {}

    def reset(self):
        self.step_count = 0

        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, len(self.x) - 1)
            self.expected_action = int(self.y[next_obs_idx])
            obs = self.x[next_obs_idx]
        else:
            # Sequential mode: walk the dataset in order (used for evaluation)
            obs = self.x[self.dataset_idx]
            self.expected_action = int(self.y[self.dataset_idx])

            self.dataset_idx += 1
            if self.dataset_idx >= len(self.x):
                raise StopIteration()

        return obs

The resulting code for the gym Env is pretty straightforward.

The only thing to note is that we can swap out the dataset and random parameters to the __init__ function to turn the gym into an evaluation bench on the test set.

The Moment of Truth

Now, we find out once and for all. Can RL be used for classification?

First, we’ll train with a Dueling Deep Q-Network.

def mnist_dqn():
    logger.configure(dir='./logs/mnist_dqn', format_strs=['stdout', 'tensorboard'])
    env = MnistEnv(images_per_episode=1)
    env = bench.Monitor(env, logger.get_dir())

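    # Assumed arguments, inferred from the description and log that follow:
    # one shared 64-unit layer, two 32-unit dueling heads, batch size 32,
    # 120,000 timesteps, exploration decaying to 1%. The ReLU activation and
    # all other hyperparameters (Baselines defaults) are guesses.
    model = deepq.learn(
        env, "mlp",
        num_layers=1, num_hidden=64, activation=tf.nn.relu,
        hiddens=[32], dueling=True,
        batch_size=32, total_timesteps=int(1.2e5),
        exploration_final_eps=0.01,
    )
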
    return model

start_time = time.time()
dqn_model = mnist_dqn()
print("DQN Training Time:", time.time() - start_time)

This is pretty standard stuff, nothing fancy going on. We’re using the DQN implementation from OpenAI Baselines.

Batch size is 32, same as the keras model, and the total timesteps are 120,000, which is two times the number of samples in the training set. This simulates 2 epochs like we are using in the keras model.

Compared to the supervised baseline, there is an architectural difference. The dueling portion of the algorithm breaks the final layer into two separate 32-unit layers instead. Those two heads are computed separately: one produces a single output (the state value) and the other produces an output for each action (the advantages). The two are then combined into one Q-value per action.

Ultimately, it means there will be slightly fewer parameters, but we’re still in the same ballpark.
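Here’s a rough sketch of both points: how a dueling network combines its two heads, and what that does to the parameter count. It assumes the mean-subtracted combination from the dueling DQN paper, 784 flattened inputs, and one shared 64-unit layer feeding the two 32-unit heads. The function names are hypothetical.

```python
def dueling_q(value, advantages):
    """Combine a state value and per-action advantages: Q = V + (A - mean(A))."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


trunk = dense_params(28 * 28, 64)                              # shared layer
value_head = dense_params(64, 32) + dense_params(32, 1)        # single state value
advantage_head = dense_params(64, 32) + dense_params(32, 10)   # one output per digit

print(dueling_q(1.0, [0.0, 2.0, 4.0]))       # [-1.0, 1.0, 3.0]
print(trunk + value_head + advantage_head)   # 54763
```

Under those assumptions the dueling network comes in just under a plain 784-64-64-10 MLP, which has 55,050 parameters.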

The final output after training looks like this:

| % time spent exploring  | 1        |
| episodes                | 1.2e+05  |
| mean 100 episode reward | 1        |
| steps                   | 1.2e+05  |

DQN Training Time: 461.527117729187

It took a lot longer (more than 30x) than the supervised baseline, but with the mean reward over the last 100 episodes rounding to 1, it looks like it was hitting close to 100% accuracy on the training set!

Let’s run an evaluation and see how it holds up on the test set.

def mnist_dqn_eval(dqn_model):
    attempts, correct = 0,0

    env = MnistEnv(images_per_episode=1, dataset=(x_test, y_test), random=False)

    # The env raises StopIteration once the test set is exhausted
    try:
        while True:
            obs, done = env.reset(), False
            while not done:
                obs, rew, done, _ = env.step(dqn_model(obs[None])[0])

                attempts += 1
                if rew > 0:
                    correct += 1

    except StopIteration:
        print('validation done...')
        print('Accuracy: {0}%'.format((float(correct) / attempts) * 100))

mnist_dqn_eval(dqn_model)


Aaaaaannnndddd… the results are:

Accuracy: 93.47869573914784%

93.5% accuracy!

Well, I guess that answers it: reinforcement learning can definitely be used to train a classifier.

That is, as long as you’re willing to wait 30x the amount of time to train it with RL.

Can we do better though? Maybe it’s just the algorithm. DQN is, after all, not the king of the RL algorithms.

Testing with the King

To find out if we can do better, we’re going to do another experiment with the king of RL algorithms.

… drumroll …

Proximal Policy Optimization (PPO), introduced by John Schulman and colleagues at OpenAI in 2017, has held the king algorithm spot for a while now.

It’s an algorithm that is flexible enough to apply to many problem types, and robust enough to not require much hyperparameter tuning.

In fact, it turned out to be so good, it surprised everyone when OpenAI used it to train DOTA 2 bots that crushed us mere mortals in professional play.

So, let’s see how it performs as a lowly classifier.

def mnist_ppo():
    logger.configure(dir='./logs/mnist_ppo', format_strs=['stdout', 'tensorboard'])
    env = DummyVecEnv([lambda: bench.Monitor(MnistEnv(images_per_episode=1), logger.get_dir())])

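    # Assumed arguments: nsteps=32 follows from the log below
    # (1.2e5 timesteps / 3.75e3 updates), and the two 64-unit layers match
    # the baseline. Other hyperparameters stay at the Baselines defaults.
    model = ppo2.learn(
        env=env, network="mlp",
        num_layers=2, num_hidden=64,
        nsteps=32, total_timesteps=int(1.2e5),
    )
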
    return model

start_time = time.time()
ppo_model = mnist_ppo()
print("PPO Training Time:", time.time() - start_time)

Same thing as with DQN, this is a standard setup using the baselines implementation of PPO with our environment.

Architecturally, the PPO model adds a value head with a single output after the final MLP layer. Ultimately, it means a few more parameters compared to the supervised baseline.
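For a sense of how few parameters that extra head adds, here’s the same back-of-the-envelope count for the PPO network. It assumes 784 flattened inputs and that the policy and value heads share the two 64-unit layers; the helper name is hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


shared_layers = dense_params(28 * 28, 64) + dense_params(64, 64)
policy_head = dense_params(64, 10)  # one logit per digit
value_head = dense_params(64, 1)    # single state-value estimate

print(shared_layers + policy_head)               # 55050, same as the supervised MLP
print(shared_layers + policy_head + value_head)  # 55115
```

Under those assumptions the value head costs only 65 extra parameters.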

After training, output looks like this:

| eplenmean               | 1        |
| eprewmean               | 0.95     |
| fps                     | 177      |
| loss/approxkl           | 0.131    |
| loss/clipfrac           | 0.0625   |
| loss/policy_entropy     | 0.0253   |
| loss/policy_loss        | -0.0292  |
| loss/value_loss         | 0.0271   |
| misc/explained_variance | 0.115    |
| misc/nupdates           | 3.75e+03 |
| misc/serial_timesteps   | 1.2e+05  |
| misc/time_elapsed       | 638      |
| misc/total_timesteps    | 1.2e+05  |
PPO Training Time: 638.657053232193
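The `loss/clipfrac` row in that log comes from PPO’s clipped objective. Here’s a minimal pure-Python sketch of the per-sample objective and the clip fraction. It is an illustration rather than the Baselines implementation, and the function names and the 0.2 clip range are assumptions.

```python
def clipped_surrogate(ratio, advantage, clip_range=0.2):
    """PPO's per-sample objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
    return min(ratio * advantage, clipped_ratio * advantage)


def clip_fraction(ratios, clip_range=0.2):
    """Share of samples whose probability ratio falls outside [1 - eps, 1 + eps]."""
    outside = [abs(r - 1.0) > clip_range for r in ratios]
    return sum(outside) / len(outside)


# A ratio of 1.5 with a positive advantage gets capped at 1.2 * advantage
print(clipped_surrogate(1.5, 2.0))
# Two of these four ratios sit outside the clip range
print(clip_fraction([1.0, 1.05, 1.5, 0.7]))  # 0.5
```

Read that way, a clip fraction of 0.0625 means only about 6% of the samples in the last update moved far enough from the old policy to be clipped.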

WHOA! Over 10 minutes of training time, and it’s not even hitting 100% on the training set.

This is not looking good, but let’s see how it does on the test set.

def mnist_ppo_eval(ppo_model):
    attempts, correct = 0,0

    env = DummyVecEnv([lambda: MnistEnv(images_per_episode=1, dataset=(x_test, y_test), random=False)])

    # The env raises StopIteration once the test set is exhausted
    try:
        while True:
            obs, done = env.reset(), [False]
            while not done[0]:
                obs, rew, done, _ = env.step(ppo_model.step(obs[None])[0])

                attempts += 1
                if rew[0] > 0:
                    correct += 1

    except StopIteration:
        print('validation done...')
        print('Accuracy: {0}%'.format((float(correct) / attempts) * 100))

mnist_ppo_eval(ppo_model)


The eval function is almost copy/paste of the DQN eval function, but accounts for the batched env.
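One subtlety in the sequential mode: both `reset()` and `step()` pull a fresh image through `_next_obs()`, so every one-step episode advances the dataset index more than once. Here’s a gym-free sketch of that bookkeeping. It assumes two draws per episode for the plain env loop and three for the `DummyVecEnv` loop (which, as far as I can tell, resets the env automatically when an episode ends, before the loop’s own `reset()` call). The function name `scored_images` is hypothetical.

```python
def scored_images(num_images, draws_per_episode):
    """Count the one-step episodes that finish before the dataset runs out.

    Mirrors the sequential branch of _next_obs: every draw advances the index,
    and the draw that moves the index to num_images raises StopIteration.
    """
    idx = 0
    scored = 0
    while True:
        for _ in range(draws_per_episode):
            idx += 1
            if idx >= num_images:
                return scored
        scored += 1


print(scored_images(10000, 2))  # 4999 (reset + step per episode)
print(scored_images(10000, 3))  # 3333 (reset + step + automatic reset)
```

If that reading is right, the DQN eval scores 4,999 of the 10,000 test images and the PPO eval 3,333, so the two accuracies are computed over different subsets. The DQN figure above is consistent with this: 4673 / 4999 is about 93.4787%.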

The output:

Accuracy: 95.1995199519952%

95.2% accuracy! Looks like PPO does have an edge over DQN, even though it takes about 40% longer to train.

Still though, we’re at about 45x the time it took to train using supervised learning, and still not at a higher accuracy.
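For reference, here’s how the ratios work out from the training times printed above. The dictionary and helper name are just for illustration.

```python
# Training times in seconds, copied from the outputs above
train_seconds = {
    "keras": 14.185736656188965,
    "dqn": 461.527117729187,
    "ppo": 638.657053232193,
}


def slowdown(method, baseline):
    """How many times longer `method` took to train than `baseline`."""
    return train_seconds[method] / train_seconds[baseline]


print(round(slowdown("dqn", "keras"), 1))  # 32.5
print(round(slowdown("ppo", "keras"), 1))  # 45.0
print(round(slowdown("ppo", "dqn"), 2))    # 1.38
```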

What’s important though is that we answered our question!

Reinforcement Learning CAN be used to train a classifier.

BUT only a maniac would wait 45x longer!

If you enjoyed this article, visit my youtube channel where I discuss various AI topics with a focus on RL.