Submission Template

Open In Colab

Submission Template

This notebook provides a suggested starter template for completing the Search of Associative Memory (SAM) model assignment.

You should submit your assignment by uploading your completed notebook to Canvas. Please ensure that your notebook runs without errors in Google Colaboratory.

Imports

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import requests
import os
from tempfile import NamedTemporaryFile

Download the dataset and store each trial as a pair of `sequence` objects: one holding the presented items and one holding the recalled items. The sequences are stored in nested dictionaries indexed as

dict[list length][presentation rate]

where each entry is a list containing one sequence per trial.
class item:
    """A study/recall item identified by ``id``.

    Passing ``val`` uses it verbatim as the id (even falsy values such as 0).
    Omitting it takes the next value from the class-wide counter ``item.idx``,
    which starts at 1, and advances that counter. Auto-numbered items
    therefore receive consecutive ids.
    """

    idx = 1  # next id handed out to an item created without ``val``

    def __init__(self, val=None):
        if val is not None:
            self.id = val
            return
        # Claim the current counter value, then advance it for the next item.
        self.id, item.idx = item.idx, item.idx + 1


class sequence:
    """An ordered collection of items for one trial.

    ``items`` is stored as given, with no copy, so the caller's list and
    ``self.items`` are the same object.
    """

    def __init__(self, items):
        self.items = items


def load_recall_data(*, timeout=30):
    """Download the Murdock (1962) free-recall files and wrap each trial in ``sequence`` objects.

    Each file is named ``fr<list length>-<presentation rate>.txt`` and holds
    one trial per line as whitespace-separated integers.

    Args:
        timeout: Seconds to wait for each HTTP request. Without a timeout,
            ``requests.get`` can block indefinitely on a stalled connection.

    Returns:
        A ``(presented, recalled)`` tuple of nested dicts indexed as
        ``d[list_length][presentation_rate]``. Each value is a list with one
        ``sequence`` per trial:
        - presented sequences hold items ``1..list_length``;
        - recalled sequences hold the line's integers in order, with every
          value equal to 88 removed.

    Raises:
        requests.HTTPError: if any file cannot be downloaded.
    """
    base_url = "https://raw.githubusercontent.com/ContextLab/memory-models-course/refs/heads/main/content/assignments/Assignment_2%3ASearch_of_Associative_Memory_Model/Murd62%20data/"
    filenames = ["fr10-2.txt", "fr15-2.txt", "fr20-1.txt", "fr20-2.txt", "fr30-1.txt", "fr40-1.txt"]
    # Dropped from every recall list. Presumably a non-list-item (intrusion)
    # code in this dataset -- TODO confirm against the dataset documentation.
    excluded_code = 88

    presented = {}
    recalled = {}

    for filename in filenames:
        # e.g. "fr20-1.txt" -> list_len=20, pres_rate=1
        list_len, pres_rate = map(int, filename.replace(".txt", "").replace("fr", "").split("-"))
        presented_trials = presented.setdefault(list_len, {}).setdefault(pres_rate, [])
        recalled_trials = recalled.setdefault(list_len, {}).setdefault(pres_rate, [])

        # Download the file
        response = requests.get(base_url + filename, timeout=timeout)
        response.raise_for_status()

        # splitlines() copes with \n, \r\n and \r line endings alike.
        for line in response.text.strip().splitlines():
            # Parse each token once; a non-integer token raises before
            # anything is appended, so both lists stay the same length.
            recall_ids = [val for val in map(int, line.split()) if val != excluded_code]
            recalled_trials.append(sequence([item(val) for val in recall_ids]))
            presented_trials.append(sequence([item(val) for val in range(1, list_len + 1)]))

    return presented, recalled

# Download every data file once at notebook start. Both results are nested
# dicts indexed as [list_length][presentation_rate] (see load_recall_data).
presented, recalled = load_recall_data()

Basic skeleton for the SAM model

class STS(object):
  """Short-term store (rehearsal buffer) for the SAM model.

  Args:
    r: number of STS slots. Sizes ``entry_times`` and appears in the
      displacement formula in ``present``.
    q: displacement parameter used by ``present``.
    s_f, s_b: forwarded to a newly created ``LTS`` when ``lts`` is not given.
    max_items: size of the newly created ``LTS``; required when ``lts`` is None.
    lts: an existing ``LTS`` to share, e.g. with the owning model.

  Raises:
    ValueError: if neither ``lts`` nor ``max_items`` is supplied.
  """

  def __init__(self, r, q, s_f, s_b, max_items=None, lts=None):
    self.r = r
    self.q = q
    if lts is None:
      if max_items is None:
        # LTS needs a concrete size to allocate its matrices.
        raise ValueError("max_items is required when no lts is supplied")
      self.LTS = LTS(max_items, s_f, s_b)
    else:
      self.LTS = lts
    self.items = []
    # One entry time per STS slot. np.zeros takes the whole shape as its
    # first argument; np.zeros(1, r, dtype=...) raised TypeError.
    self.entry_times = np.zeros(r, dtype=np.int32)

  def present(self, x):
    """Add item ``x`` to the STS, displacing an existing item if no slot is free.

    Not implemented yet. Intended displacement rule, with the items in STS
    ranked by relative age i = 1..r (presumably i = 1 is the oldest --
    TODO confirm):

        p(displace item i) = q * (1 - q)**(i - 1) / (1 - (1 - q)**r)

    The denominator equals sum_{i=1}^{r} q * (1 - q)**(i - 1), so the r
    displacement probabilities sum to 1.
    """
    # check current capacity; if available capacity, add item to STS. else displace items.
    pass


class LTS(object):
  """Long-term store for the SAM model.

  Holds a ``(max_items, max_items)`` float32 matrix ``S`` (presumably
  item-to-item associative strengths) and a length-``max_items`` float32
  vector ``context`` (presumably item-to-context strengths). Both start at
  zero. Also stores the last recalled item, initially None.
  """

  def __init__(self, max_items, s_f, s_b):
    self.max_items = max_items
    # Parameters kept for update(); presumably forward/backward scaling.
    self.s_f, self.s_b = s_f, s_b
    square = (max_items, max_items)
    self.S = np.zeros(square, dtype=np.float32)
    self.context = np.zeros(max_items, dtype=np.float32)
    self.previous_recall = None

  def update(self, items):
    """Update ``S`` and ``context`` from the items currently in STS (not implemented yet)."""
    pass

class SAM(object):
  """Search of Associative Memory model: an STS buffer backed by a shared LTS.

  Args:
    W_c, W_e: weights used by ``retrieve`` (presumably context and item cue
      weights -- TODO confirm).
    M_1, M_2: limits on the two failure counters used by ``retrieve``.
    r, q: STS slot count and displacement parameter.
    max_items: size of the LTS matrices.
    s_f, s_b: parameters passed to the LTS (presumably forward/backward
      association scaling). They default to None because the original
      signature offered no way to supply them. Set them before
      ``LTS.update`` is implemented.
  """

  def __init__(self, W_c, W_e, M_1, M_2, r, q, max_items=100, s_f=None, s_b=None):
    self.W_c = W_c
    self.W_e = W_e
    self.M_1 = M_1
    self.M_2 = M_2
    self.m1_count = 0
    self.m2_count = 0
    self.r = r
    self.q = q

    # Build one LTS and hand it to the STS, so the model and its STS refer
    # to the same LTS object. Previously the code called STS(r, q, max_items)
    # and LTS(max_items). Neither matches its constructor, so both raised
    # TypeError, and the intent would have produced two separate LTS objects.
    self.LTS = LTS(max_items, s_f, s_b)
    self.STS = STS(r, q, s_f, s_b, max_items=max_items, lts=self.LTS)

  def present(self, x):
    """Present item ``x``: pass it to the STS, then update the LTS from the STS contents."""
    self.STS.present(x)
    self.LTS.update(self.STS.items)

  def retrieve(self):  # retrieve a *single item*
    # if there's anything in STS, retrieve and remove it
    # else:
    #    - sample (from context and/or context + prev item) until we get something other than the previous_recall.
    #             (if previous_recall is the only item left, return None)
    #    - recall (given cue strength):
    #       - if successful, reset m1_count and m2_count, set previous_recall to item, return sampled item
    #       - otherwise increment m1_count or m2_count.  if either exceeds M_1/M_2, return None
    pass

Other tasks:

  • Fit the model parameters to the Murdock (1962) dataset that you downloaded with the load_recall_data function.

    • You’ll need to define a “loss” function. I suggest computing the mean squared error (MSE) between observed and predicted values of one or more behavioral curves, using a subset of the Murdock (1962) participants/lists.

    • I suggest using scipy.optimize.minimize to estimate the model parameters.

  • Create observed/predicted plots for held-out data:

    • p(first recall)

    • p(recall)

    • lag-CRP