Using RL to Model Cognitive Tasks¶
By Neuromatch Academy
Content creators: Morteza Ansarinia, Yamil Vidal
Production editor: Spiros Chavlis
Objective¶
This project aims to use behavioral data to train an agent and then use the agent to investigate data produced by human subjects. With a computational agent that mimics human behavior in such tests, we can compare its mechanics with human data.
Alternatively, we could fit a single agent that learns many cognitive tasks requiring abstract constructs such as executive functions. This is a multi-task control problem.
Setup¶
Install dependencies¶
# @title Install dependencies
!pip install --upgrade pip setuptools wheel --quiet
!pip install "dm-acme[jax,tf]" --quiet
!pip install dm-sonnet --quiet
!pip install trfl --quiet
# Imports
import time
import numpy as np
import pandas as pd
import sonnet as snt
import seaborn as sns
import matplotlib.pyplot as plt
import dm_env
import acme
from acme import specs
from acme import wrappers
from acme import EnvironmentLoop
from acme.agents.tf import dqn
from acme.utils import loggers
from IPython.display import HTML, display, clear_output
Background¶
Cognitive scientists use standard lab tests to tap into specific processes in the brain and behavior. Examples of such tests are the Stroop, N-back, Digit Span, TMT (Trail Making Test), and WCST (Wisconsin Card Sorting Test).
Despite an extensive body of research that explains human performance using descriptive what-models, we still need a more sophisticated approach to gain a better understanding of the underlying processes (i.e., a how-model).
Interestingly, many such tests can be thought of as a continuous stream of stimuli and corresponding actions, which is consonant with the RL formulation. In fact, RL itself is in part motivated by how the brain enables goal-directed behaviors using reward systems, making it a good choice to explain human performance.
One behavioral test example would be the N-back task.
In the N-back, participants view a sequence of stimuli, one by one, and are asked to categorize each stimulus as being either match or non-match. Stimuli are usually numbers, and feedback is given at both timestep and trajectory levels.
The agent is rewarded when its response is correct with respect to the stimulus that was shown N steps back in the episode. A simpler version of the N-back uses a two-choice action scheme: match vs. non-match. When the present stimulus matches the one presented N steps back, the agent is expected to respond with match.
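The matching rule can be made concrete with a small sketch. The helper name `expected_responses` is hypothetical (it is not part of this notebook), and labeling the first N trials as `None` is an assumption of the sketch, since those trials have nothing to be compared against.

```python
def expected_responses(stimuli, n=2):
    """Label each stimulus as 'match' or 'non-match' under the N-back rule.

    The first `n` trials have no stimulus `n` steps back, so they are
    labeled None (an assumption of this sketch).
    """
    labels = []
    for t, stimulus in enumerate(stimuli):
        if t < n:
            labels.append(None)
        elif stimulus == stimuli[t - n]:
            labels.append('match')
        else:
            labels.append('non-match')
    return labels


# In a 2-back, the third 'A' matches the first one, and the last 'A'
# matches the 'A' shown two trials earlier.
print(expected_responses(list('ABAACA'), n=2))
```

With `n=1` the same helper compares each stimulus to the one immediately before it.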
Given a trained RL agent, we can then look for correlates of its fitted parameters in brain mechanisms. The most straightforward analysis would be to correlate the model parameters with brain activity.
Datasets¶
HCP WM task (NMA-CN HCP notebooks)
Any dataset that used cognitive tests would work.
Question: should we limit ourselves to behavioral data, or also use fMRI?
Question: which stimuli and actions should we use? Classic tests can be modeled using bounded symbolic stimuli/actions (e.g., A, B, C), but more sophisticated ones would require text or images (e.g., face vs. neutral images in the social Stroop dataset).
The HCP dataset from NMA-CN contains behavioral and imaging data for 7 cognitive tests, including various versions of the N-back.
N-back task¶
In the N-back task, participants view a sequence of stimuli, one at a time, and are asked to categorize each stimulus as either a match or a non-match. Stimuli are usually numbers, and feedback is given at both timestep and trajectory levels.
In a typical neuro setup, both accuracy and response time are measured, but here, for the sake of brevity, we focus only on accuracy of responses.
Cognitive Tests Environment¶
First, we develop an environment in which agents perform a cognitive test, here the N-back.
Human dataset¶
We need a dataset of humans performing an N-back test, with the following features:
participant_id: following the BIDS format, it contains a unique identifier for each participant.
trial_index: same as time_step.
stimulus: same as observation.
response: same as action, recorded response by the human subject.
expected_response: correct response.
is_correct: same as reward, whether the human subject responded correctly.
response_time: won’t be used here.
Here we generate a mock dataset with those features, but remember to replace this with real human data.
def generate_mock_nback_dataset(N=2,
                                n_participants=10,
                                n_trials=32,
                                stimulus_choices=list('ABCDEF'),
                                response_choices=['match', 'non-match']):
    """Generate a mock dataset for the N-back task."""

    n_rows = n_participants * n_trials

    # one block of `n_trials` consecutive rows per participant
    participant_ids = sorted([f'sub-{pid}' for pid in range(1, n_participants + 1)] * n_trials)
    trial_indices = list(range(1, n_trials + 1)) * n_participants

    # stimuli, responses, and response times are all drawn at random
    stimulus_sequence = np.random.choice(stimulus_choices, n_rows)
    responses = np.random.choice(response_choices, n_rows)
    response_times = np.random.exponential(size=n_rows)

    df = pd.DataFrame({
        'participant_id': participant_ids,
        'trial_index': trial_indices,
        'stimulus': stimulus_sequence,
        'response': responses,
        'response_time': response_times
    })

    # mark matching stimuli; shift within each participant so that one
    # participant's last trials are not compared with the next one's first trials
    _nback_stim = df.groupby('participant_id')['stimulus'].shift(N)
    df['expected_response'] = (df['stimulus'] == _nback_stim).map({True: 'match', False: 'non-match'})

    df['is_correct'] = (df['response'] == df['expected_response'])

    # we don't care about burn-in trials (trial_index <= N)
    df.loc[df['trial_index'] <= N, 'is_correct'] = True
    df.loc[df['trial_index'] <= N, ['response', 'response_time', 'expected_response']] = None

    return df
# ========
# now generate the actual data with the provided function and plot some of its features
mock_nback_data = generate_mock_nback_dataset()
mock_nback_data['is_correct'] = mock_nback_data['is_correct'].astype(int)
sns.displot(data=mock_nback_data, x='response_time')
plt.suptitle('Response time distribution of the mock N-back dataset', y=1.01)
plt.show()
sns.displot(data=mock_nback_data, x='is_correct')
plt.suptitle('Accuracy distribution of the mock N-back dataset', y=1.06)
plt.show()
sns.barplot(data=mock_nback_data, y='is_correct', x='participant_id')
plt.suptitle('Accuracy per participant in the mock N-back dataset', y=1.06)
plt.show()
mock_nback_data.head()
| | participant_id | trial_index | stimulus | response | response_time | expected_response | is_correct |
|---|---|---|---|---|---|---|---|
| 0 | sub-1 | 1 | B | None | NaN | None | 1 |
| 1 | sub-1 | 2 | B | None | NaN | None | 1 |
| 2 | sub-1 | 3 | D | non-match | 1.637974 | non-match | 1 |
| 3 | sub-1 | 4 | C | non-match | 0.096110 | non-match | 1 |
| 4 | sub-1 | 5 | B | non-match | 0.703303 | non-match | 1 |
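Because the mock generator marks burn-in trials as correct, averaging `is_correct` over all rows overstates accuracy. The following is a minimal sketch of a per-participant accuracy that skips those trials. The function name `accuracy_per_participant` and the toy table are illustrative assumptions, not part of the notebook or real data.

```python
import pandas as pd


def accuracy_per_participant(df, n=2):
    """Mean of `is_correct` per participant, ignoring the first `n` burn-in trials."""
    scored = df[df['trial_index'] > n]
    return scored.groupby('participant_id')['is_correct'].mean()


# Tiny hand-made table with the same column names as the mock dataset.
toy = pd.DataFrame({
    'participant_id': ['sub-1'] * 4 + ['sub-2'] * 4,
    'trial_index': [1, 2, 3, 4] * 2,
    'is_correct': [1, 1, 1, 0, 1, 1, 1, 1],
})

print(accuracy_per_participant(toy, n=2))
```

The same call should work on `mock_nback_data`, provided `n` equals the `N` used to generate it.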
Implementation scheme¶
Environment¶
The following cell implements the N-back environment, which we later use to train an RL agent on human data. It is capable of performing two kinds of simulation:
It rewards the agent when its action is correct (i.e., a normative model of the environment).
It receives human data (or mock data if you prefer) and rewards the agent when its action reproduces what the participant did on that trial. This is more useful for preference-based RL.
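The two modes differ only in what the agent's action is compared against. Here is a minimal sketch of that comparison as a standalone function. The name `step_reward` is hypothetical, the 0 / -1 reward values mirror the ones used in the environment below, and treating trials with fewer than N predecessors as non-match is an assumption of the sketch.

```python
def step_reward(agent_action, stimuli, t, n=2, human_action=None):
    """Return 0.0 when the agent's action is accepted and -1.0 otherwise.

    If `human_action` is given, the agent is scored against the participant's
    response (imitation mode); otherwise it is scored against the N-back rule
    (normative mode).
    """
    if human_action is not None:
        target = human_action
    else:
        # Assumption: trials without a stimulus `n` steps back count as non-match.
        is_match = t >= n and stimuli[t] == stimuli[t - n]
        target = 'match' if is_match else 'non-match'
    return 0.0 if agent_action == target else -1.0


stimuli = [1, 2, 1]
print(step_reward('match', stimuli, t=2))                            # normative mode
print(step_reward('match', stimuli, t=2, human_action='non-match'))  # imitation mode
```

In the second call the agent is penalized for giving the correct answer, because the (hypothetical) participant answered differently.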
class NBack(dm_env.Environment):
    """N-back task implemented as a `dm_env` environment."""

    ACTIONS = ['match', 'non-match']

    def __init__(self,
                 N=2,
                 episode_steps=32,
                 stimuli_choices=list('ABCDEF'),
                 human_data=None,
                 seed=1,
                 ):
        """
        Args:
          N: Number of steps to look back for the matched stimuli. Defaults to 2 (as in 2-back).
          episode_steps: Number of trials (time steps) in each episode.
          stimuli_choices: Symbols that the stimuli are drawn from.
          human_data: Optional DataFrame of human trials with the features listed above.
            If given, the agent is rewarded for reproducing the human responses.
          seed: Random seed (currently unused).
        """
        self.N = N
        self.episode_steps = episode_steps
        self.stimuli_choices = stimuli_choices
        self.stimuli = np.empty(shape=episode_steps)  # will be filled in the `reset()`

        self._reset_next_step = True
        self._current_step = 0

        # whether to mimic humans or to reward the agent once it responds optimally.
        self._imitate_human = human_data is not None
        self.human_data = human_data
        self.human_subject_data = None

        self._action_history = []

    def reset(self):
        self._reset_next_step = False
        self._current_step = 0
        self._action_history.clear()

        if self.human_data is None:
            # generate a random sequence instead of relying on human data.
            # FIXME Stimuli are encoded as numbers because of an acme & reverb issue with
            # string observations. The agent should be able to handle strings.
            # Codes start at 1 (as in the human-data branch) so that 0 is reserved for padding.
            self.stimuli = (np.random.choice(len(self.stimuli_choices), self.episode_steps) + 1).astype(np.float32)
        else:
            # randomly choose a subject from the human data and follow her trials and responses.
            # FIXME should we always use one specific human subject or randomly select one in each episode?
            subject_id = np.random.choice(self.human_data['participant_id'].unique())
            self.human_subject_data = (
                self.human_data[self.human_data['participant_id'] == subject_id]
                .sort_values('trial_index'))
            self.stimuli = self.human_subject_data['stimulus'].to_list()
            self.stimuli = np.array([ord(s) - ord('A') + 1 for s in self.stimuli]).astype(np.float32)

        return dm_env.restart(self._observation())

    def _episode_return(self):
        if self._imitate_human:
            # fraction of trials on which the agent reproduced the human response
            return np.mean(self.human_subject_data['response'] == self._action_history)
        else:
            return 0.0

    def step(self, action: int):
        if self._reset_next_step:
            return self.reset()

        agent_action = NBack.ACTIONS[action]

        if self._imitate_human:
            # if it was the same action as the human subject, then reward the agent
            human_action = self.human_subject_data['response'].iloc[self._current_step]
            step_reward = 0. if (agent_action == human_action) else -1.
        else:
            # assume the agent is rational and doesn't want to reproduce human behavior;
            # reward it once the response is correct.
            # Burn-in trials (fewer than N previous stimuli) are never a match; without this
            # check the negative index would wrap around to the end of the sequence.
            is_match = (self._current_step >= self.N and
                        self.stimuli[self._current_step] == self.stimuli[self._current_step - self.N])
            expected_action = 'match' if is_match else 'non-match'
            step_reward = 0. if (agent_action == expected_action) else -1.

        self._action_history.append(agent_action)
        self._current_step += 1

        # Check for termination.
        if self._current_step == self.stimuli.shape[0]:
            self._reset_next_step = True
            # the last reward is the episode-level return (see `_episode_return`),
            # which replaces the reward of the final step
            return dm_env.termination(reward=self._episode_return(),
                                      observation=self._observation())
        else:
            return dm_env.transition(reward=step_reward,
                                     observation=self._observation())

    def observation_spec(self):
        return dm_env.specs.BoundedArray(
            shape=self.stimuli.shape,
            dtype=self.stimuli.dtype,
            name='nback_stimuli', minimum=0, maximum=len(self.stimuli_choices) + 1)

    def action_spec(self):
        return dm_env.specs.DiscreteArray(
            num_values=len(NBack.ACTIONS),
            dtype=np.int32,
            name='action')

    def _observation(self):
        # agent observes only the current trial
        # obs = self.stimuli[self._current_step - 1]

        # agents observe stimuli up to the current trial, zero-padded to the episode length
        obs = self.stimuli[:self._current_step + 1].copy()
        obs = np.pad(obs, (0, len(self.stimuli) - len(obs)))

        return obs

    def plot_state(self):
        """Display current state of the environment.

        Note: `M` means `match`, and `.` is a `non-match`.
        Requires `HTML` from `IPython.display`.
        """
        stimuli = self.stimuli[:self._current_step - 1]
        actions = ['M' if a == 'match' else '.' for a in self._action_history[:self._current_step - 1]]
        return HTML(
            f'<b>Environment ({self.N}-back):</b><br />'
            f'<pre><b>Stimuli:</b> {"".join(map(str, map(int, stimuli)))}</pre>'
            f'<pre><b>Actions:</b> {"".join(actions)}</pre>'
        )

    @staticmethod
    def create_environment():
        """Utility function to create an N-back environment and its spec."""

        # Make sure the environment outputs single-precision floats.
        environment = wrappers.SinglePrecisionWrapper(NBack())

        # Grab the spec of the environment.
        environment_spec = specs.make_environment_spec(environment)

        return environment, environment_spec
Define a random agent¶
For more information you can refer to NMA-DL W3D2 Basic Reinforcement learning.
class RandomAgent(acme.Actor):

    def __init__(self, environment_spec):
        """Gets the number of available actions from the environment spec."""
        self._num_actions = environment_spec.actions.num_values

    def select_action(self, observation):
        """Selects an action uniformly at random."""
        action = np.random.randint(self._num_actions)
        return action

    def observe_first(self, timestep):
        """Does not record as the RandomAgent has no use for data."""
        pass

    def observe(self, action, next_timestep):
        """Does not record as the RandomAgent has no use for data."""
        pass

    def update(self):
        """Does not update as the RandomAgent does not learn from data."""
        pass
Initialize the environment and the agent¶
env, env_spec = NBack.create_environment()
agent = RandomAgent(env_spec)
print('actions:\n', env_spec.actions)
print('observations:\n', env_spec.observations)
print('rewards:\n', env_spec.rewards)
actions:
DiscreteArray(shape=(), dtype=int32, name=action, minimum=0, maximum=1, num_values=2)
observations:
BoundedArray(shape=(32,), dtype=dtype('float32'), name='nback_stimuli', minimum=0.0, maximum=7.0)
rewards:
Array(shape=(), dtype=dtype('float32'), name='reward')
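The observation spec has shape `(32,)` because each observation is the full stimulus history so far, zero-padded to the episode length, rather than the current stimulus alone. Here is a minimal standalone sketch of that encoding. The function name `encode_observation` is hypothetical, and it mirrors what the environment's `_observation` method does instead of replacing it.

```python
import numpy as np


def encode_observation(stimuli, current_step):
    """Return stimuli[0..current_step], zero-padded to the full episode length."""
    stimuli = np.asarray(stimuli, dtype=np.float32)
    visible = stimuli[:current_step + 1]
    # Trials that have not been shown yet are filled with zeros.
    return np.pad(visible, (0, len(stimuli) - len(visible)))


# Four-trial episode, seen at the second trial: two stimuli visible, two padded.
print(encode_observation([3, 1, 3, 2], current_step=1))
```

One consequence of this design is that a stimulus code of 0 is indistinguishable from padding, so stimulus codes are best kept strictly positive.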
Run the loop¶
For more details, see NMA-DL W3D2.
# fitting parameters
n_episodes = 1_000
n_total_steps = 0
log_loss = False
n_steps = n_episodes * 32
all_returns = []

# main loop
for episode in range(n_episodes):
    episode_steps = 0
    episode_return = 0
    episode_loss = 0

    start_time = time.time()

    timestep = env.reset()

    # Make the first observation.
    agent.observe_first(timestep)

    # Run an episode
    while not timestep.last():

        # DEBUG
        # print(timestep)

        # Generate an action from the agent's policy and step the environment.
        action = agent.select_action(timestep.observation)
        timestep = env.step(action)

        # Have the agent observe the timestep and let the agent update itself.
        agent.observe(action, next_timestep=timestep)
        agent.update()

        # Book-keeping.
        episode_steps += 1
        n_total_steps += 1
        episode_return += timestep.reward

        if log_loss:
            episode_loss += agent.last_loss

        if n_steps is not None and n_total_steps >= n_steps:
            break

    # Collect the results and combine with counts.
    steps_per_second = episode_steps / (time.time() - start_time)
    result = {
        'episode': episode,
        'episode_length': episode_steps,
        'episode_return': episode_return,
    }
    if log_loss:
        result['loss_avg'] = episode_loss/episode_steps

    all_returns.append(episode_return)

    display(env.plot_state())
    # Log the given results.
    print(result)

    if n_steps is not None and n_total_steps >= n_steps:
        break

clear_output()
# Histogram of all returns
plt.figure()
sns.histplot(all_returns, stat="density", kde=True, bins=12)
plt.xlabel('Return [a.u.]')
plt.ylabel('Density')
plt.show()
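As a sanity check on the histogram, the chance-level return can be worked out by hand. The environment gives 0 for an accepted action and -1 otherwise, and the reward of the final step is replaced by the episode-level value (0.0 in normative mode). A uniformly random two-choice agent therefore errs with probability 0.5 on each of the remaining 31 steps, for an expected return of -15.5. The sketch below reproduces this with a plain simulation that does not use the environment; the function name and its parameters are illustrative assumptions.

```python
import random


def simulate_random_returns(n_episodes=2000, scored_steps=31, seed=0):
    """Simulate episode returns of an agent that errs with probability 0.5 per step."""
    rng = random.Random(seed)
    returns = []
    for _ in range(n_episodes):
        # Each scored step costs -1 with probability 0.5.
        errors = sum(rng.random() < 0.5 for _ in range(scored_steps))
        returns.append(-float(errors))
    return returns


returns = simulate_random_returns()
mean_return = sum(returns) / len(returns)
print(mean_return)  # should land near the analytic value of -15.5
```

A trained agent that has actually learned the task should produce returns well above this baseline.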
Note: You can simplify the environment loop using DeepMind Acme.
# init a new N-back environment
env, env_spec = NBack.create_environment()
# DEBUG fake testing environment.
# Uncomment this to debug your agent without using the N-back environment.
# from acme.testing import fakes
# env = fakes.DiscreteEnvironment(
#     num_actions=2,
#     num_observations=1000,
#     obs_dtype=np.float32,
#     episode_length=32)
# env_spec = specs.make_environment_spec(env)


def dqn_make_network(action_spec: specs.DiscreteArray) -> snt.Module:
    """Builds a small MLP Q-network with one output per action."""
    return snt.Sequential([
        snt.Flatten(),
        snt.nets.MLP([50, 50, action_spec.num_values]),
    ])


# construct a DQN agent
agent = dqn.DQN(
    environment_spec=env_spec,
    network=dqn_make_network(env_spec.actions),
    epsilon=[0.5],
    logger=loggers.InMemoryLogger(),
    checkpoint=False,
)
Now, we run the environment loop with the DQN agent and print the training log.
# training loop
loop = EnvironmentLoop(env, agent, logger=loggers.InMemoryLogger())
loop.run(n_episodes)
# print logs
logs = pd.DataFrame(loop._logger._data)
logs.tail()
| | episode_length | episode_return | steps_per_second | episodes | steps |
|---|---|---|---|---|---|
| 995 | 32 | -8.0 | 364.026677 | 996 | 31872 |
| 996 | 32 | -10.0 | 345.859582 | 997 | 31904 |
| 997 | 32 | -10.0 | 370.050615 | 998 | 31936 |
| 998 | 32 | -16.0 | 379.613671 | 999 | 31968 |
| 999 | 32 | -16.0 | 367.838720 | 1000 | 32000 |
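Single-episode returns are noisy, so a rolling mean is one common way to read a learning curve from such a log. This is a minimal sketch, assuming a DataFrame with an `episode_return` column like the one above. The function name `smooth_returns` is hypothetical, and the toy frame reuses only the five returns shown in the table.

```python
import pandas as pd


def smooth_returns(logs, window=100):
    """Rolling mean of `episode_return`; early episodes use a shorter window."""
    return logs['episode_return'].rolling(window, min_periods=1).mean()


# The five episode returns from the table above.
toy_logs = pd.DataFrame({'episode_return': [-8.0, -10.0, -10.0, -16.0, -16.0]})
print(smooth_returns(toy_logs, window=2).tolist())
```

On the full log, plotting `smooth_returns(logs)` against the `episodes` column would show whether the return trends upward over training.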