Bonus Tutorial: Understanding Pre-training, Fine-tuning and Robustness of Transformers
Contents
Bonus Tutorial: Understanding Pre-training, Fine-tuning and Robustness of Transformers¶
Week 2, Day 5: Attention and Transformers
By Neuromatch Academy
Content creators: Bikram Khastgir, Rajaswa Patil, Egor Zverev, Kelson Shilling-Scrivo, Alish Dipani, He He
Content reviewers: Ezekiel Williams, Melvin Selim Atay, Khalid Almubarak, Lily Cheng, Hadi Vafaei, Kelson Shilling-Scrivo
Content editors: Gagana B, Anoop Kulkarni, Spiros Chavlis
Production editors: Khalid Almubarak, Gagana B, Spiros Chavlis
Tutorial Objectives¶
By the end of this tutorial, you will be able to:
Write down the objective of language model pre-training
Understand the framework of pre-training then fine-tuning
Name three types of biases in pre-trained language models
Setup¶
In this section, we will install and import the libraries and helper functions needed for this tutorial.
⚠ Experimental LLM-enhanced tutorial ⚠
This notebook includes Neuromatch’s experimental Chatify 🤖 functionality. The Chatify notebook extension adds support for a large language model-based “coding tutor” to the materials. The tutor provides automatically generated text to help explain any code cell in this notebook.
Note that using Chatify may cause breaking changes and/or provide incorrect or misleading information. If you wish to proceed by installing and enabling the Chatify extension, you should run the next two code blocks (hidden by default). If you do not want to use this experimental version of the Neuromatch materials, please use the stable materials instead.
To use the Chatify helper, insert the %%explain magic command at the start of any code cell and then run it (shift + enter) to access an interface for receiving LLM-based assistance. You can then select different options from the dropdown menus depending on what sort of assistance you want. To disable Chatify and run the code block as usual, simply delete the %%explain command and re-run the cell.
Note that, by default, all of Chatify’s responses are generated locally. This often takes several minutes per response. Once you click the “Submit request” button, just be patient: the response is being generated even if you can’t see anything right away!
Thanks for giving Chatify a try! Love it? Hate it? Either way, we’d love to hear from you about your Chatify experience! Please consider filling out our brief survey to provide feedback and help us make Chatify more awesome!
Run the next two cells to install and configure Chatify…
%pip install -q davos
import davos
davos.config.suppress_stdout = True
Note: you may need to restart the kernel to use updated packages.
smuggle chatify # pip: git+https://github.com/ContextLab/chatify.git
%load_ext chatify
Using default configuration!
Downloading the 'cache' file.
Install dependencies¶
There may be errors and/or warnings reported during the installation. However, they can be safely ignored.
# @title Install dependencies
# @markdown There may be *errors* and/or *warnings* reported during the installation. However, they can be safely ignored.
!pip install datasets --quiet
!pip install accelerate --quiet
!pip install transformers --quiet
Install and import feedback gadget¶
# @title Install and import feedback gadget
!pip3 install vibecheck datatops --quiet
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_dl",
"user_key": "f379rz8y",
},
).render()
feedback_prefix = "W2D5_T2_Bonus"
Set environment variables¶
# @title Set environment variables
import os
os.environ['TA_CACHE_DIR'] = 'data/'
os.environ['NLTK_DATA'] = 'nltk_data/'
# Imports
import os
import nltk
import torch
import random
import string
import datasets
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pprint import pprint
from tqdm.notebook import tqdm
from abc import ABC, abstractmethod
from nltk.corpus import brown
from gensim.models import Word2Vec
from sklearn.manifold import TSNE
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchtext.vocab import Vectors
# transformers library
from transformers import Trainer
from transformers import pipeline
from transformers import AutoTokenizer
from transformers import TrainingArguments
from transformers import AutoModelForCausalLM
from transformers import AutoModelForSequenceClassification
%load_ext tensorboard
Set random seed¶
Executing `set_seed(seed=seed)` sets the seed
# @title Set random seed
# @markdown Executing `set_seed(seed=seed)` sets the seed
# For DL it's critical to set the random seed so that students can have a
# baseline to compare their results to expected results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html
# Call `set_seed` function in the exercises to ensure reproducibility.
import random
import torch
def set_seed(seed=None, seed_torch=True):
"""
Handles variability by controlling sources of randomness
through set seed values
Args:
seed: Integer
Set the seed value to given integer.
If None, set the seed to a random integer in the range [0, 2^32)
seed_torch: Bool
Seeds the random number generator for all devices to
offer some guarantees on reproducibility
Returns:
Nothing
"""
if seed is None:
seed = np.random.choice(2 ** 32)
random.seed(seed)
np.random.seed(seed)
if seed_torch:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
print(f'Random seed {seed} has been set.')
# In case that `DataLoader` is used
def seed_worker(worker_id):
"""
Reseeds each DataLoader worker so that randomness in
multi-process data loading stays reproducible.
Args:
worker_id: integer
ID of subprocess to seed. 0 means that
the data will be loaded in the main process
Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details
Returns:
Nothing
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
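For completeness, here is a minimal sketch (following the PyTorch reproducibility notes linked above) of how `seed_worker` would typically be wired into a `DataLoader` together with a seeded generator. The toy dataset below is only an illustration and is not used elsewhere in this notebook.
# Example (sketch): pass `seed_worker` and a seeded generator to a DataLoader
from torch.utils.data import DataLoader, TensorDataset
toy_dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))  # placeholder dataset
g = torch.Generator()
g.manual_seed(2021)  # e.g., the same seed value used below
loader = DataLoader(toy_dataset,
                    batch_size=16,
                    shuffle=True,
                    num_workers=2,
                    worker_init_fn=seed_worker,  # reseeds each worker process
                    generator=g)                 # controls the shuffling order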
Set device (GPU or CPU). Execute `set_device()`¶
# @title Set device (GPU or CPU). Execute `set_device()`
# especially if torch modules used.
# inform the user if the notebook uses GPU or CPU.
def set_device():
"""
Set the device. CUDA if available, CPU otherwise
Args:
None
Returns:
Nothing
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
print("WARNING: For this notebook to perform best, "
"if possible, in the menu under `Runtime` -> "
"`Change runtime type.` select `GPU` ")
else:
print("GPU is enabled in this notebook.")
return device
SEED = 2021
set_seed(seed=SEED)
DEVICE = set_device()
Random seed 2021 has been set.
WARNING: For this notebook to perform best, if possible, in the menu under `Runtime` -> `Change runtime type.` select `GPU`
Bonus 1: Language modeling as pre-training¶
Time estimate: ~20mins
Video 1: Pre-training¶
Submit your feedback¶
# @title Submit your feedback
content_review(f"{feedback_prefix}_PreTraining_Video")
Bonus Interactive Demo 1: GPT-2 for sentiment classification¶
In this section, we will use the pre-trained language model GPT-2 for sentiment classification.
Let’s first load the Yelp review dataset.
Download the dataset from OSF¶
# @title Download the dataset from OSF
import requests, tarfile
os.environ['HF_DATASETS_CACHE'] = 'data/'
url = "https://osf.io/kthjg/download"
fname = "huggingface.tar.gz"
if not os.path.exists(fname):
print('Dataset is being downloaded...')
r = requests.get(url, allow_redirects=True)
with open(fname, 'wb') as fd:
fd.write(r.content)
print('Download is finished.')
with tarfile.open(fname) as ft:
ft.extractall('data/')
os.remove(fname)
print('Files have been extracted.')
Load the dataset¶
# @title Load the dataset
DATASET = datasets.load_dataset("yelp_review_full",
download_mode="reuse_dataset_if_exists",
cache_dir='data/')
print(type(DATASET))
<class 'datasets.dataset_dict.DatasetDict'>
# If the above cell produces an error uncomment below and run this cell.
# DATASET = load_dataset("yelp_review_full", ignore_verifications=True)
Bonus 1.1: Load Yelp reviews dataset ⌛🤗¶
# @title Bonus 1.1: Load Yelp reviews dataset ⌛🤗
train_dataset = DATASET['train']
test_dataset = DATASET['test']
# filter training data by sentiment value
sentiment_dict = {}
sentiment_dict["Sentiment = 0"] = train_dataset.filter(lambda example: example['label']==0)
sentiment_dict["Sentiment = 1"] = train_dataset.filter(lambda example: example['label']==1)
sentiment_dict["Sentiment = 2"] = train_dataset.filter(lambda example: example['label']==2)
sentiment_dict["Sentiment = 3"] = train_dataset.filter(lambda example: example['label']==3)
sentiment_dict["Sentiment = 4"] = train_dataset.filter(lambda example: example['label']==4)
Kaggle users: If the cell above fails, please re-execute it several times!
Next, we’ll set up a text context for the pre-trained language models. We can either sample a review from the Yelp reviews dataset or write our own custom review as the text context. We will perform text-generation and sentiment-classification with this text context.
Bonus 1.2: Setting up a text context ✍️¶
# @title Bonus 1.2: Setting up a text context ✍️
def clean_text(text):
"""
Function to clean up text
Args:
text: String
Input text sequence
Returns:
text: String
Returned clean string does not contain new-line characters,
backslashes etc.
"""
text = text.replace("\\n", " ")
text = text.replace("\n", " ")
text = text.replace("\\", " ")
return text
# @markdown ---
sample_review_from_yelp = "Sentiment = 4" # @param ["Sentiment = 0", "Sentiment = 1", "Sentiment = 2", "Sentiment = 3", "Sentiment = 4"]
# @markdown **Randomly sample a response from the Yelp review dataset with the given sentiment value {0:😠, 1:😦, 2:😐, 3:🙂, 4:😀}**
# @markdown ---
use_custom_review = False # @param {type:"boolean"}
custom_review = "I liked this movie very much because ..." # @param {type:"string"}
# @markdown ***Alternatively, write your own review (don't forget to enable custom review using the checkbox given above)***
# @markdown ---
# @markdown **NOTE:** *Run the cell after setting all the above fields appropriately!*
print("\n ****** The selected text context ****** \n")
if use_custom_review:
context = clean_text(custom_review)
else:
context = clean_text(sentiment_dict[sample_review_from_yelp][random.randint(0,len(sentiment_dict[sample_review_from_yelp])-1)]["text"])
pprint(context)
****** The selected text context ******
("They carry Eegee's!!!!!!! I love Eegee's but am rarely in Tuscon. Finally I "
'have a place to go satisfy my cravings without having to drive 90 min. '
'Their subs are pretty good too :)')
Here, we’ll ask the pre-trained language models to extend the selected text context further. You can try adding different kinds of extension prompts at the end of the text context, conditioning it for different kinds of text extensions.
Bonus 1.3: Extending the review with pre-trained models 🤖¶
# @title Bonus 1.3: Extending the review with pre-trained models 🤖
# @markdown ---
model = "gpt2" # @param ["gpt2", "gpt2-medium", "xlnet-base-cased"]
generator = pipeline('text-generation', model=model)
set_seed(seed=SEED)
# @markdown **Select a pre-trained language model to generate text 🤖**
# @markdown *(might take some time to download the pre-trained weights for the first time)*
# @markdown ---
extension_prompt = "Hence, overall I feel that ..." # @param {type:"string"}
num_output_responses = 1 # @param {type:"slider", min:1, max:10, step:1}
# @markdown **Provide a prompt to extend the review ✍️**
input_text = context + " " + extension_prompt
# @markdown **NOTE:** *Run this cell after setting all the fields appropriately!*
# @markdown **NOTE:** *Some pre-trained models might not work well with longer texts!*
generated_responses = generator(input_text, max_length=512, num_return_sequences=num_output_responses)
print("\n *********** INPUT PROMPT TO THE MODEL ************ \n")
pprint(input_text)
print("\n *********** EXTENDED RESPONSES BY THE MODEL ************ \n")
for response in generated_responses:
pprint(response["generated_text"][len(input_text):] + " ...")
print()
Next, we’ll ask the pre-trained language models to calculate the likelihood of already existing text-extensions. We can define a positive text-extension as well as a negative text-extension. The sentiment of the given text context can then be determined by comparing the likelihoods of the given text extensions.
(For a positive review, a positive text-extension should ideally be given more likelihood by the pre-trained language model as compared to a negative text-extension. Similarly, for a negative review, the negative text-extension should have more likelihood than the positive text-extension.)
Bonus 1.4: Sentiment binary-classification with likelihood of positive and negative extensions of the review 👍👎¶
# @title Bonus 1.4: Sentiment binary-classification with likelihood of positive and negative extensions of the review 👍👎
# @markdown ---
model_name = "gpt2" # @param ["gpt2", "gpt2-medium", "xlnet-base-cased"]
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# @markdown **Select a pre-trained language model to score the likelihood of extended review**
# @markdown *(might take some time to download the pre-trained weights for the first time)*
# @markdown ---
custom_positive_extension = "I would definitely recommend this!" # @param {type:"string"}
custom_negative_extension = "I would not recommend this!" # @param {type:"string"}
# @markdown **Provide custom positive and negative extensions to the review ✍️**
texts = [context, custom_positive_extension, custom_negative_extension]
encodings = tokenizer(texts)
positive_input_ids = torch.tensor(encodings["input_ids"][0] + encodings["input_ids"][1])
positive_attention_mask = torch.tensor(encodings["attention_mask"][0] + encodings["attention_mask"][1])
positive_label_ids = torch.tensor([-100]*len(encodings["input_ids"][0]) + encodings["input_ids"][1])
outputs = model(input_ids=positive_input_ids,
attention_mask=positive_attention_mask,
labels=positive_label_ids)
positive_extension_likelihood = -1*outputs.loss
print("\nLog-likelihood of positive extension = ", positive_extension_likelihood.item())
negative_input_ids = torch.tensor(encodings["input_ids"][0] + encodings["input_ids"][2])
negative_attention_mask = torch.tensor(encodings["attention_mask"][0] + encodings["attention_mask"][2])
negative_label_ids = torch.tensor([-100]*len(encodings["input_ids"][0]) + encodings["input_ids"][2])
outputs = model(input_ids=negative_input_ids,
attention_mask=negative_attention_mask,
labels=negative_label_ids)
negative_extension_likelihood = -1*outputs.loss
print("\nLog-likelihood of negative extension = ", negative_extension_likelihood.item())
if (positive_extension_likelihood.item() > negative_extension_likelihood.item()):
print("\nPositive text-extension has greater likelihood probabilities!")
print("The given review can be predicted to be POSITIVE 👍")
else:
print("\nNegative text-extension has greater likelihood probabilities!")
print("The given review can be predicted to be NEGATIVE 👎")
# @markdown **NOTE:** *Run this cell after setting all the fields appropriately!*
# @markdown **NOTE:** *Some pre-trained models might not work well with longer texts!*
Log-likelihood of positive extension = -3.4624080657958984
Log-likelihood of negative extension = -3.913834810256958
Positive text-extension has greater likelihood probabilities!
The given review can be predicted to be POSITIVE 👍
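The comparison above can be wrapped into a small helper. The sketch below (with a hypothetical name, extension_logprob) reuses the model, tokenizer and review context loaded in the cell above. Note that the Hugging Face causal-LM loss is the mean cross-entropy over the non-masked label tokens, so -loss is an average per-token log-likelihood.
# Hypothetical helper (sketch): average per-token log-likelihood of an extension
def extension_logprob(model, tokenizer, context, extension):
    ctx_ids = tokenizer(context)["input_ids"]
    ext_ids = tokenizer(extension)["input_ids"]
    input_ids = torch.tensor(ctx_ids + ext_ids).unsqueeze(0)
    # Mask the context tokens with -100 so only the extension contributes to the loss
    labels = torch.tensor([-100] * len(ctx_ids) + ext_ids).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=labels)
    return -outputs.loss.item()

print(extension_logprob(model, tokenizer, context, custom_positive_extension))
print(extension_logprob(model, tokenizer, context, custom_negative_extension))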
Bonus 2: Light-weight fine-tuning¶
Time estimate: ~10mins
Video 2: Fine-tuning¶
Submit your feedback¶
# @title Submit your feedback
content_review(f"{feedback_prefix}_FineTuning_Video")
Fine-tuning these large pre-trained models with billions of parameters tends to be very slow. In this section, we will explore the effect of fine-tuning a few layers (while fixing the others) to save training time.
The HuggingFace python library provides a simplified API for training and fine-tuning transformer language models. In this exercise we will fine-tune a pre-trained language model for sentiment classification.
Bonus 2.1: Data Processing¶
Pre-trained transformer models have a fixed vocabulary of words and sub-words. The input text to a transformer model has to be tokenized into these words and sub-words during the pre-processing stage. We’ll use the HuggingFace tokenizers
to perform the tokenization here.
(By default we’ll use the BERT base-cased pre-trained language model here. You can try using one of the other models available on the HuggingFace Hub by changing the model ID values at appropriate places in the code.)
Most of the pre-trained language models have a fixed maximum sequence length. With the HuggingFace tokenizer
library, we can either pad or truncate input text sequences to the maximum length with a few lines of code:
# Tokenize the input texts
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
def tokenize_function(examples):
"""
Tokenises incoming sequences.
Args:
examples: Sequence of strings
Sequences to tokenise
Returns:
Returns transformer autotokenizer object with padded, truncated input sequences.
"""
return tokenizer(examples["text"], padding="max_length", truncation=True)
# Here we use the `DATASET` as defined above.
# Recall that DATASET = load_dataset("yelp_review_full", ignore_verifications=True)
tokenized_datasets = DATASET.map(tokenize_function, batched=True)
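To see what the tokenizer produces, you can inspect one mapped example (a quick sketch; the field names follow the standard BERT tokenizer output):
# Inspect one tokenized example (sketch)
example = tokenized_datasets["train"][0]
print(example.keys())              # includes input_ids, token_type_ids, attention_mask
print(len(example["input_ids"]))   # padded/truncated to the model's max length (512)
print(tokenizer.convert_ids_to_tokens(example["input_ids"][:10]))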
We’ll randomly sample a subset of the Yelp reviews dataset (10k training samples, and 2.5k samples each for validation and testing). You can include more samples here for better performance (at the cost of longer training times!)
# Select the data splits
train_dataset = tokenized_datasets["train"].shuffle(seed=SEED).select(range(10000))
test_dataset = tokenized_datasets["test"].select(range(0, 2500))
validation_dataset = tokenized_datasets["test"].select(range(2500, 5000))
Bonus 2.2: Model Loading¶
Next, we’ll load a pre-trained checkpoint of the model and decide which layers are to be fine-tuned.
Modify the train_layers
variable below to pick which layers you would like to fine-tune (you can uncomment the print statements for this). Fine-tuning more layers might result in better performance (at the cost of longer training times). Due to computational limitations (limited GPU memory) we cannot fine-tune the entire model.
# Load pre-trained BERT model and freeze layers
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased",
cache_dir="data/",
num_labels=5)
train_layers = ["classifier", "bert.pooler", "bert.encoder.layer.11"] # add/remove layers here (use layer-name sub-strings)
for name, param in model.named_parameters():
if any(x in name for x in train_layers):
param.requires_grad = True
# print("FINE-TUNING -->", name)
else:
param.requires_grad = False
# print("FROZEN -->", name)
Bonus 2.3: Fine-tuning¶
Fine-tune the model! The HuggingFace Trainer
class supports easy fine-tuning and logging. You can play around with various hyperparameters here!
# Setup huggingface trainer
training_args = TrainingArguments(output_dir="yelp_bert",
overwrite_output_dir=True,
evaluation_strategy="epoch",
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
learning_rate=5e-5,
weight_decay=0.0,
num_train_epochs=1, # students may use 5 to see a full training!
fp16=False if DEVICE=='cpu' else True,
save_steps=50,
logging_steps=10,
report_to="tensorboard"
)
We’ll use Accuracy
as the evaluation metric for the sentiment classification task. The HuggingFace datasets
library supports various metrics. You can try experimenting with other classification metrics here!
# Setup evaluation metric
def compute_metrics(eval_pred):
"""
Computes accuracy of the prediction
Args:
eval_pred: Tuple
Logits predicted by the model vs actual labels
Returns:
Dictionary containing accuracy of the prediction
"""
metric = datasets.load_metric("accuracy")
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
accuracy = metric.compute(predictions=predictions, references=labels)["accuracy"]
return {"accuracy": accuracy}
Start the training!
# Instantiate a trainer with training and validation datasets
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
compute_metrics=compute_metrics,
tokenizer=tokenizer
)
# Train the model
if DEVICE != 'cpu':
trainer.train()
# Evaluate the model on the test dataset
if DEVICE != 'cpu':
trainer.evaluate(test_dataset)
We can now visualize the Tensorboard
logs to analyze the training process! The HuggingFace Trainer
class will log various loss values and evaluation metrics automatically!
# Visualize the tensorboard logs
if DEVICE != 'cpu':
%tensorboard --logdir yelp_bert/runs
Bonus 3: Model robustness¶
Time estimate: ~22mins
Video 3: Robustness¶
Submit your feedback¶
# @title Submit your feedback
content_review(f"{feedback_prefix}_Robustness_Video")
Given the previously trained model for sentiment classification, it is possible to deceive it using various text perturbations. The text perturbations can act as previously unseen noise to the model, which might cause it to predict the wrong sentiment!
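As a minimal illustration of the idea (a sketch assuming the fine-tuned `model` and `tokenizer` from Bonus 2 are still in memory, and using a hand-written typo-perturbed review rather than the textattack transformations introduced below), you can compare predictions on a clean and a noisily perturbed input:
# Sketch: compare predictions on an original vs. character-perturbed review
original = "The food was absolutely wonderful and the staff were lovely."
perturbed = "The food was absolutley wonderfull and the staff were lovley."  # typo noise

for text in (original, perturbed):
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    print(f"{text}\n  -> predicted sentiment: {logits.argmax(dim=-1).item()}")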
Bonus Interactive Demo 3: Break the model¶
Bonus 3.1: Load an original review¶
# @title Bonus 3.1: Load an original review
def clean_text(text):
"""
Function to clean up text
Args:
text: String
Input text sequence
Returns:
text: String
Returned string does not contain new-line characters, backslashes, etc.
"""
text = text.replace("\\n", " ")
text = text.replace("\n", " ")
text = text.replace("\\", " ")
return text
# @markdown ---
sample_review_from_yelp = "Sentiment = 4" #@param ["Sentiment = 0", "Sentiment = 1", "Sentiment = 2", "Sentiment = 3", "Sentiment = 4"]
# @markdown **Randomly sample a response from the Yelp review dataset with the given sentiment value {0:😠, 1:😦, 2:😐, 3:🙂, 4:😀}**
# @markdown ---
context = clean_text(sentiment_dict[sample_review_from_yelp][random.randint(0,len(sentiment_dict[sample_review_from_yelp])-1)]["text"])
print("Review for ", sample_review_from_yelp, ":\n")
pprint(context)
Review for Sentiment = 4 :
('I discovered this restaurant when I was living in Montreal. It serves South '
"Indian and Sri Lankan food. It remains my favourite restaurant ever. I'm "
'living in Ottawa now, but every time I visit Montreal, I have to go to '
'Jolee. The decor is nothing fancy, but the food is so delicious. There '
'is a take out counter at the back of the restaurant and people are '
'constantly coming in for take out. The Beef Rolls are amazing. Spicy '
'beef and potato, rolled up and then deep fried. Yummy goodness! The Fish '
'Cutlets are good too (fish and potato). My favourite dish is the Chicken '
'Kottu Roti. The portion is huge and has lots of chicken, egg, onions and '
'roti pieces. The Beef Biriyani is excellent too (lots of beef, cashews, and '
'a boiled egg). I find that the dishes here are spicy. I like spicy, but I '
'can only take a "medium spicy " here. I usually just order mild, which has '
'plenty of bite. The prices are so awesome. The huge Chicken Kottu Roti '
'is like only $7 and can easily feed two people. The Beef Rolls are less '
'than $2 a piece (and they are bigger than usual egg rolls). Their '
"desserts are by the weight and very reasonable in cost. I usually don't "
'like Indian desserts because I find them too sweet, but I love the desserts '
'at Jolee. I usually order a box to go and get a piece of everything. I '
"have to admit that I don't know the names of any of the desserts or what is "
'actually in them, but I know that they are colourful, pretty and delicious '
":) Writing about this place is making me miss it so much. If you haven't "
"tried it yet, you should go, you won't regret it!")
Submit your feedback¶
# @title Submit your feedback
content_review(f"{feedback_prefix}_Load_an_original_review_Interactive_Demo")
We can apply various text perturbations to the selected review using the textattack
python library. This will help us augment the original text to break the model!
Important: Locally, or on Colab (prefixed with !), you can simply run
pip install textattack --quiet
Then, import the packages:
from textattack.augmentation import Augmenter
from textattack.transformations import WordSwapQWERTY
from textattack.transformations import WordSwapExtend
from textattack.transformations import WordSwapContract
from textattack.transformations import WordSwapHomoglyphSwap
from textattack.transformations import CompositeTransformation
from textattack.transformations import WordSwapRandomCharacterDeletion
from textattack.transformations import WordSwapNeighboringCharacterSwap
from textattack.transformations import WordSwapRandomCharacterInsertion
from textattack.transformations import WordSwapRandomCharacterSubstitution
However, as we have faced installation issues with textattack, you can instead run the cell below, which loads all the necessary classes and functions.
Helper functions to avoid `textattack` issue¶
# @title Helper functions to avoid `textattack` issue
!pip install flair --quiet
import flair
from collections import OrderedDict
from flair.data import Sentence
"""
Word Swap
-------------------------------
Word swap transformations act by
replacing some words in the input.
Subclasses can implement the abstract WordSwap class by
overriding self._get_replacement_words
"""
def default_class_repr(self):
"""
Formats given input
Args:
None
Returns:
Formatted string with additional parameters
"""
if hasattr(self, "extra_repr_keys"):
extra_params = []
for key in self.extra_repr_keys():
extra_params.append(" (" + key + ")" + ": {" + key + "}")
if len(extra_params):
extra_str = "\n" + "\n".join(extra_params) + "\n"
extra_str = f"({extra_str})"
else:
extra_str = ""
extra_str = extra_str.format(**self.__dict__)
else:
extra_str = ""
return f"{self.__class__.__name__}{extra_str}"
LABEL_COLORS = [
"red",
"green",
"blue",
"purple",
"yellow",
"orange",
"pink",
"cyan",
"gray",
"brown",
]
class Transformation(ABC):
"""
An abstract class for transforming a sequence of text to produce a
potential adversarial example.
"""
def __call__(
self,
current_text,
pre_transformation_constraints=[],
indices_to_modify=None,
shifted_idxs=False,
):
"""
Applies the pre_transformation_constraints then calls
_get_transformations.
Args:
current_text: String
The AttackedText Object to transform.
pre_transformation_constraints: List
The PreTransformationConstraint to apply for cross-checking transformation compatibility.
indices_to_modify: Integer
Word indices to be modified as dictated by the SearchMethod.
shifted_idxs: Boolean
Indicates whether indices could be shifted from their original position in the text.
Returns:
transformed_texts: List
Returns a list of all possible transformations for current_text.
"""
if indices_to_modify is None:
indices_to_modify = set(range(len(current_text.words)))
# If we are modifying all indices, we don't care if some of the indices might have been shifted.
shifted_idxs = False
else:
indices_to_modify = set(indices_to_modify)
if shifted_idxs:
indices_to_modify = set(
current_text.convert_from_original_idxs(indices_to_modify)
)
for constraint in pre_transformation_constraints:
indices_to_modify = indices_to_modify & constraint(current_text, self)
transformed_texts = self._get_transformations(current_text, indices_to_modify)
for text in transformed_texts:
text.attack_attrs["last_transformation"] = self
return transformed_texts
@abstractmethod
def _get_transformations(self, current_text, indices_to_modify):
"""
Returns a list of all possible transformations for current_text,
only modifying indices_to_modify.
Must be overridden by specific transformations.
Args:
current_text: String
The AttackedText Object to transform.
indices_to_modify: Integer
Specifies word indices which can be modified.
Returns:
Nothing
"""
raise NotImplementedError()
@property
def deterministic(self):
return True
def extra_repr_keys(self):
return []
__repr__ = __str__ = default_class_repr
class WordSwap(Transformation):
"""
An abstract class that takes a sentence and transforms it by replacing
some of its words.
"""
def __init__(self, letters_to_insert=None):
"""
Initializes following attributes
Args:
letters_to_insert: String
Letters allowed for insertion into words (used by some char-based transformations)
Returns:
Nothing
"""
self.letters_to_insert = letters_to_insert
if not self.letters_to_insert:
self.letters_to_insert = string.ascii_letters
def _get_replacement_words(self, word):
"""
Returns a set of replacements given an input word.
Must be overridden by specific word swap transformations.
Args:
word: String
The input word for which replacements are to be found.
Returns:
Nothing
"""
raise NotImplementedError()
def _get_random_letter(self):
"""
Helper function that returns a random single letter from the English
alphabet that could be lowercase or uppercase.
Args:
None
Returns:
Random Single Letter to simulate random-letter transformation
"""
return random.choice(self.letters_to_insert)
def _get_transformations(self, current_text, indices_to_modify):
"""
Returns a list of all possible transformations for current_text,
only modifying indices_to_modify.
Must be overridden by specific transformations.
Args:
current_text: String
The AttackedText Object to transform.
indices_to_modify: Integer
Which word indices can be modified.
Returns:
transformed_texts: List
List of all transformed texts i.e., index at which transformation was applied
"""
words = current_text.words
transformed_texts = []
for i in indices_to_modify:
word_to_replace = words[i]
replacement_words = self._get_replacement_words(word_to_replace)
transformed_texts_idx = []
for r in replacement_words:
if r == word_to_replace:
continue
transformed_texts_idx.append(current_text.replace_word_at_index(i, r))
transformed_texts.extend(transformed_texts_idx)
return transformed_texts
class WordSwapQWERTY(WordSwap):
"""
A transformation that swaps characters with adjacent keys on a
QWERTY keyboard, replicating the kind of errors that come from typing
too quickly.
"""
def __init__(
self, random_one=True, skip_first_char=False, skip_last_char=False, **kwargs
):
"""
Initiates the following attributes
Args:
random_one: Boolean
Specifies whether to return a single (random) swap, or all possible swaps.
skip_first_char: Boolean
When True, do not modify the first character of each word.
skip_last_char: Boolean
When True, do not modify the last character of each word.
Usage/Example:
>>> from textattack.transformations import WordSwapQWERTY
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapQWERTY()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = 'I am fabulous.'
>>> augmenter.augment(s)
Returns:
Nothing
"""
super().__init__(**kwargs)
self.random_one = random_one
self.skip_first_char = skip_first_char
self.skip_last_char = skip_last_char
self._keyboard_adjacency = {
"q": [
"w",
"a",
"s",
],
"w": ["q", "e", "a", "s", "d"],
"e": ["w", "s", "d", "f", "r"],
"r": ["e", "d", "f", "g", "t"],
"t": ["r", "f", "g", "h", "y"],
"y": ["t", "g", "h", "j", "u"],
"u": ["y", "h", "j", "k", "i"],
"i": ["u", "j", "k", "l", "o"],
"o": ["i", "k", "l", "p"],
"p": ["o", "l"],
"a": ["q", "w", "s", "z", "x"],
"s": ["q", "w", "e", "a", "d", "z", "x"],
"d": ["w", "e", "r", "f", "c", "x", "s"],
"f": ["e", "r", "t", "g", "v", "c", "d"],
"g": ["r", "t", "y", "h", "b", "v", "d"],
"h": ["t", "y", "u", "g", "j", "b", "n"],
"j": ["y", "u", "i", "k", "m", "n", "h"],
"k": ["u", "i", "o", "l", "m", "j"],
"l": ["i", "o", "p", "k"],
"z": ["a", "s", "x"],
"x": ["s", "d", "z", "c"],
"c": ["x", "d", "f", "v"],
"v": ["c", "f", "g", "b"],
"b": ["v", "g", "h", "n"],
"n": ["b", "h", "j", "m"],
"m": ["n", "j", "k"],
}
def _get_adjacent(self, s):
"""
Helper function to extract keys adjacent to given input key
Args:
s: String
Letter for which adjacent keys are to be queried
Returns:
adjacent_keys: List
List of co-occurring keys with respect to input
"""
s_lower = s.lower()
if s_lower in self._keyboard_adjacency:
adjacent_keys = self._keyboard_adjacency.get(s_lower, [])
if s.isupper():
return [key.upper() for key in adjacent_keys]
else:
return adjacent_keys
else:
return []
def _get_replacement_words(self, word):
"""
Helper function to find candidate words with respect to given input key.
Candidate words are words selected based on nearest neighbors
with scope for subsequent swapping.
Args:
word: String
Word for which candidate words are to be generated.
Returns:
candidate_words: List
List of candidate words with respect to input word.
"""
if len(word) <= 1:
return []
candidate_words = []
start_idx = 1 if self.skip_first_char else 0
end_idx = len(word) - (1 + self.skip_last_char)
if start_idx >= end_idx:
return []
if self.random_one:
i = random.randrange(start_idx, end_idx + 1)
candidate_word = (
word[:i] + random.choice(self._get_adjacent(word[i])) + word[i + 1 :]
)
candidate_words.append(candidate_word)
else:
for i in range(start_idx, end_idx + 1):
for swap_key in self._get_adjacent(word[i]):
candidate_word = word[:i] + swap_key + word[i + 1 :]
candidate_words.append(candidate_word)
return candidate_words
@property
def deterministic(self):
return not self.random_one
EXTENSION_MAP = {"ain't": "isn't", "aren't": 'are not', "can't": 'cannot', "can't've": 'cannot have', "could've": 'could have', "couldn't": 'could not', "didn't": 'did not', "doesn't": 'does not', "don't": 'do not', "hadn't": 'had not', "hasn't": 'has not', "haven't": 'have not', "he'd": 'he would', "he'd've": 'he would have', "he'll": 'he will', "he's": 'he is', "how'd": 'how did', "how'd'y": 'how do you', "how'll": 'how will', "how's": 'how is', "I'd": 'I would', "I'll": 'I will', "I'm": 'I am', "I've": 'I have', "i'd": 'i would', "i'll": 'i will', "i'm": 'i am', "i've": 'i have', "isn't": 'is not', "it'd": 'it would', "it'll": 'it will', "it's": 'it is', "ma'am": 'madam', "might've": 'might have', "mightn't": 'might not', "must've": 'must have', "mustn't": 'must not', "needn't": 'need not', "oughtn't": 'ought not', "shan't": 'shall not', "she'd": 'she would', "she'll": 'she will', "she's": 'she is', "should've": 'should have', "shouldn't": 'should not', "that'd": 'that would', "that's": 'that is', "there'd": 'there would', "there's": 'there is', "they'd": 'they would', "they'll": 'they will', "they're": 'they are', "they've": 'they have', "wasn't": 'was not', "we'd": 'we would', "we'll": 'we will', "we're": 'we are', "we've": 'we have', "weren't": 'were not', "what're": 'what are', "what's": 'what is', "when's": 'when is', "where'd": 'where did', "where's": 'where is', "where've": 'where have', "who'll": 'who will', "who's": 'who is', "who've": 'who have', "why's": 'why is', "won't": 'will not', "would've": 'would have', "wouldn't": 'would not', "you'd": 'you would', "you'd've": 'you would have', "you'll": 'you will', "you're": 'you are', "you've": 'you have'}
class WordSwap(Transformation):
"""
An abstract class that takes a sentence and transforms it by replacing
some of its words.
"""
def __init__(self, letters_to_insert=None):
"""
Initiates the following attributes
Args:
letters_to_insert: String
Letters allowed for insertion into words
(used by some char-based transformations)
Returns:
Nothing
"""
self.letters_to_insert = letters_to_insert
if not self.letters_to_insert:
self.letters_to_insert = string.ascii_letters
def _get_replacement_words(self, word):
"""
Returns a set of replacements given an input word.
Must be overridden by specific word swap transformations.
Args:
word: String
The input word to find replacements for.
Returns:
Nothing
"""
raise NotImplementedError()
def _get_random_letter(self):
"""
Helper function that returns a random single letter from the English
alphabet that could be lowercase or uppercase.
Args:
None
Returns:
Random single letter for random-letter transformation
"""
return random.choice(self.letters_to_insert)
def _get_transformations(self, current_text, indices_to_modify):
"""
Returns a list of all possible transformations for current_text,
only modifying indices_to_modify.
Must be overridden by specific transformations.
Args:
current_text: String
The AttackedText Object to transform.
indices_to_modify: Integer
Which word indices can be modified.
Returns:
transformed_texts: List
List of all transformed texts with indexes at which transformation was applied
"""
words = current_text.words
transformed_texts = []
for i in indices_to_modify:
word_to_replace = words[i]
replacement_words = self._get_replacement_words(word_to_replace)
transformed_texts_idx = []
for r in replacement_words:
if r == word_to_replace:
continue
transformed_texts_idx.append(current_text.replace_word_at_index(i, r))
transformed_texts.extend(transformed_texts_idx)
return transformed_texts
class WordSwapExtend(WordSwap):
"""
Transforms an input by performing extension on recognized
combinations.
"""
def _get_transformations(self, current_text, indices_to_modify):
"""
Return all possible transformed sentences, each with one extension.
Args:
current_text: String
The AttackedText Object to transform.
indices_to_modify: Integer
Which word indices can be modified.
Returns:
transformed_texts: List
List of all transformed texts based on extension map
Usage/Examples:
>>> from textattack.transformations import WordSwapExtend
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapExtend()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = '''I'm fabulous'''
>>> augmenter.augment(s)
"""
transformed_texts = []
words = current_text.words
for idx in indices_to_modify:
word = words[idx]
# expand when the word is in the map
if word in EXTENSION_MAP:
expanded = EXTENSION_MAP[word]
transformed_text = current_text.replace_word_at_index(idx, expanded)
transformed_texts.append(transformed_text)
return transformed_texts
class WordSwapContract(WordSwap):
"""
Transforms an input by performing contraction on recognized
combinations.
"""
reverse_contraction_map = {v: k for k, v in EXTENSION_MAP.items()}
def _get_transformations(self, current_text, indices_to_modify):
"""
Return all possible transformed sentences, each with one
contraction.
Args:
current_text: String
The AttackedText Object to transform.
indices_to_modify: Integer
Which word indices can be modified.
Returns:
transformed_texts: List
List of all transformed texts based on reverse contraction map
Usage/Example:
>>> from textattack.transformations import WordSwapContract
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapContract()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = 'I am 12 years old.'
>>> augmenter.augment(s)
"""
transformed_texts = []
words = current_text.words
indices_to_modify = sorted(indices_to_modify)
# search for every 2-words combination in reverse_contraction_map
for idx, word_idx in enumerate(indices_to_modify[:-1]):
next_idx = indices_to_modify[idx + 1]
if (idx + 1) != next_idx:
continue
word = words[word_idx]
next_word = words[next_idx]
# generating the words to search for
key = " ".join([word, next_word])
# when a possible contraction is found in map, contract the current text
if key in self.reverse_contraction_map:
transformed_text = current_text.replace_word_at_index(
idx, self.reverse_contraction_map[key]
)
transformed_text = transformed_text.delete_word_at_index(next_idx)
transformed_texts.append(transformed_text)
return transformed_texts
class WordSwapHomoglyphSwap(WordSwap):
"""
Transforms an input by replacing its words with visually similar words
using homoglyph swaps.
A homoglyph is one of two or more graphemes, characters, or glyphs
with shapes that appear identical or very similar.
"""
def __init__(self, random_one=False, **kwargs):
"""
Initiates the following attributes
Args:
random_one: Boolean
Choosing random substring for transformation
Returns:
Nothing
Usage/Examples:
>>> from textattack.transformations import WordSwapHomoglyphSwap
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapHomoglyphSwap()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = 'I am fabulous.'
>>> augmenter.augment(s)
"""
super().__init__(**kwargs)
self.homos = {
"-": "˗",
"9": "৭",
"8": "Ȣ",
"7": "𝟕",
"6": "б",
"5": "Ƽ",
"4": "Ꮞ",
"3": "Ʒ",
"2": "ᒿ",
"1": "l",
"0": "O",
"'": "`",
"a": "ɑ",
"b": "Ь",
"c": "ϲ",
"d": "ԁ",
"e": "е",
"f": "𝚏",
"g": "ɡ",
"h": "հ",
"i": "і",
"j": "ϳ",
"k": "𝒌",
"l": "ⅼ",
"m": "m",
"n": "ո",
"o": "о",
"p": "р",
"q": "ԛ",
"r": "ⲅ",
"s": "ѕ",
"t": "𝚝",
"u": "ս",
"v": "ѵ",
"w": "ԝ",
"x": "×",
"y": "у",
"z": "ᴢ",
}
self.random_one = random_one
def _get_replacement_words(self, word):
"""
Returns a list containing all possible words with 1 character
replaced by a homoglyph.
Args:
word: String
Word for which homoglyphs are to be generated.
Returns:
candidate_words: List
List of homoglyphs with respect to input word.
"""
candidate_words = []
if self.random_one:
i = np.random.randint(0, len(word))
if word[i] in self.homos:
repl_letter = self.homos[word[i]]
candidate_word = word[:i] + repl_letter + word[i + 1 :]
candidate_words.append(candidate_word)
else:
for i in range(len(word)):
if word[i] in self.homos:
repl_letter = self.homos[word[i]]
candidate_word = word[:i] + repl_letter + word[i + 1 :]
candidate_words.append(candidate_word)
return candidate_words
@property
def deterministic(self):
return not self.random_one
def extra_repr_keys(self):
return super().extra_repr_keys()
class WordSwapRandomCharacterDeletion(WordSwap):
"""
Transforms an input by deleting its characters.
"""
def __init__(
self, random_one=True, skip_first_char=False, skip_last_char=False, **kwargs
):
"""
Initiates the following parameters:
Args:
random_one: Boolean
Whether to return a single word with a random
character deleted. If not, returns all possible options.
skip_first_char: Boolean
Whether to disregard deleting the first character.
skip_last_char: Boolean
Whether to disregard deleting the last character.
Returns:
Nothing
Usage/Example:
>>> from textattack.transformations import WordSwapRandomCharacterDeletion
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapRandomCharacterDeletion()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = 'I am fabulous.'
>>> augmenter.augment(s)
"""
super().__init__(**kwargs)
self.random_one = random_one
self.skip_first_char = skip_first_char
self.skip_last_char = skip_last_char
def _get_replacement_words(self, word):
"""
Returns a list containing all possible words with 1 letter
deleted.
Args:
word: String
The input word to find replacements for.
Returns:
candidate_words: List
List of candidate words with single letter deletion
"""
if len(word) <= 1:
return []
candidate_words = []
start_idx = 1 if self.skip_first_char else 0
end_idx = (len(word) - 1) if self.skip_last_char else len(word)
if start_idx >= end_idx:
return []
if self.random_one:
i = np.random.randint(start_idx, end_idx)
candidate_word = word[:i] + word[i + 1 :]
candidate_words.append(candidate_word)
else:
for i in range(start_idx, end_idx):
candidate_word = word[:i] + word[i + 1 :]
candidate_words.append(candidate_word)
return candidate_words
@property
def deterministic(self):
return not self.random_one
def extra_repr_keys(self):
return super().extra_repr_keys() + ["random_one"]
class WordSwapNeighboringCharacterSwap(WordSwap):
"""
Transforms an input by replacing its words with a neighboring character
swap.
"""
def __init__(
self, random_one=True, skip_first_char=False, skip_last_char=False, **kwargs
):
"""
Initiates the following attributes
Args:
random_one: Boolean
Whether to return a single word with two characters
swapped. If not, returns all possible options.
skip_first_char: Boolean
Whether to disregard perturbing the first
character.
skip_last_char: Boolean
Whether to disregard perturbing the last
character.
Returns:
Nothing
Usage/Examples:
>>> from textattack.transformations import WordSwapNeighboringCharacterSwap
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapNeighboringCharacterSwap()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = 'I am fabulous.'
>>> augmenter.augment(s)
"""
super().__init__(**kwargs)
self.random_one = random_one
self.skip_first_char = skip_first_char
self.skip_last_char = skip_last_char
def _get_replacement_words(self, word):
"""
Returns a list containing all possible words with a single pair of
neighboring characters swapped.
Args:
word: String
The input word to find replacements for.
Returns:
candidate_words: List
List of candidate words
"""
if len(word) <= 1:
return []
candidate_words = []
start_idx = 1 if self.skip_first_char else 0
end_idx = (len(word) - 2) if self.skip_last_char else (len(word) - 1)
if start_idx >= end_idx:
return []
if self.random_one:
i = np.random.randint(start_idx, end_idx)
candidate_word = word[:i] + word[i + 1] + word[i] + word[i + 2 :]
candidate_words.append(candidate_word)
else:
for i in range(start_idx, end_idx):
candidate_word = word[:i] + word[i + 1] + word[i] + word[i + 2 :]
candidate_words.append(candidate_word)
return candidate_words
@property
def deterministic(self):
return not self.random_one
def extra_repr_keys(self):
return super().extra_repr_keys() + ["random_one"]
class WordSwapRandomCharacterInsertion(WordSwap):
"""
Transforms an input by inserting a random character.
"""
def __init__(
self, random_one=True, skip_first_char=False, skip_last_char=False, **kwargs
):
"""
Initiates the following attributes
Args:
random_one: Boolean
Whether to return a single word with a random
character deleted. If not, returns all possible options.
skip_first_char: Boolean
Whether to disregard inserting as the first character.
skip_last_char: Boolean
Whether to disregard inserting as the last character.
Returns:
Nothing
Usage/Example:
>>> from textattack.transformations import WordSwapRandomCharacterInsertion
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapRandomCharacterInsertion()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = 'I am fabulous.'
>>> augmenter.augment(s)
"""
super().__init__(**kwargs)
self.random_one = random_one
self.skip_first_char = skip_first_char
self.skip_last_char = skip_last_char
def _get_replacement_words(self, word):
"""
Returns a list containing all possible words with 1 random
character inserted.
Args:
word: String
The input word to find replacements for.
Returns:
candidate_words: List
List of candidate words with all possible words with 1 random
character inserted.
"""
if len(word) <= 1:
return []
candidate_words = []
start_idx = 1 if self.skip_first_char else 0
end_idx = (len(word) - 1) if self.skip_last_char else len(word)
if start_idx >= end_idx:
return []
if self.random_one:
i = np.random.randint(start_idx, end_idx)
candidate_word = word[:i] + self._get_random_letter() + word[i:]
candidate_words.append(candidate_word)
else:
for i in range(start_idx, end_idx):
candidate_word = word[:i] + self._get_random_letter() + word[i:]
candidate_words.append(candidate_word)
return candidate_words
@property
def deterministic(self):
return not self.random_one
def extra_repr_keys(self):
return super().extra_repr_keys() + ["random_one"]
class WordSwapRandomCharacterSubstitution(WordSwap):
"""
Transforms an input by replacing one character in a word with a random
new character.
"""
def __init__(self, random_one=True, **kwargs):
"""
Initiates the following attributes
Args:
random_one: Boolean
Whether to return a single word with a random
character deleted. If not set, returns all possible options.
Returns:
Nothing
Usage/Example:
>>> from textattack.transformations import WordSwapRandomCharacterSubstitution
>>> from textattack.augmentation import Augmenter
>>> transformation = WordSwapRandomCharacterSubstitution()
>>> augmenter = Augmenter(transformation=transformation)
>>> s = 'I am fabulous.'
>>> augmenter.augment(s)
"""
super().__init__(**kwargs)
self.random_one = random_one
def _get_replacement_words(self, word):
"""
Returns a list containing all possible words with 1 letter
substituted for a random letter.
Args:
word: String
The input word to find replacements for.
Returns:
candidate_words: List
List of candidate words with combinations involving random substitution
"""
if len(word) <= 1:
return []
candidate_words = []
if self.random_one:
i = np.random.randint(0, len(word))
candidate_word = word[:i] + self._get_random_letter() + word[i + 1 :]
candidate_words.append(candidate_word)
else:
for i in range(len(word)):
candidate_word = word[:i] + self._get_random_letter() + word[i + 1 :]
candidate_words.append(candidate_word)
return candidate_words
@property
def deterministic(self):
return not self.random_one
def extra_repr_keys(self):
return super().extra_repr_keys() + ["random_one"]
class CompositeTransformation(Transformation):
"""
A transformation which applies each of a list of transformations,
returning a set of all options.
"""
def __init__(self, transformations):
"""
Initiates the following attributes
Args:
transformations: List
The list of Transformation to apply.
Returns:
Nothing
"""
if not (
isinstance(transformations, list) or isinstance(transformations, tuple)
):
raise TypeError("transformations must be list or tuple")
elif not len(transformations):
raise ValueError("transformations cannot be empty")
self.transformations = transformations
def _get_transformations(self, *_):
"""
Placeholder method that would throw an error if a user tried to
treat the CompositeTransformation as a 'normal' transformation.
Args:
None
Returns:
Nothing
"""
raise RuntimeError(
"CompositeTransformation does not support _get_transformations()."
)
def __call__(self, *args, **kwargs):
"""
Generates new attacked texts based on different possible transformations
Args:
None
Returns:
new_attacked_texts: List
List of new attacked texts based on different possible transformations
"""
new_attacked_texts = set()
for transformation in self.transformations:
new_attacked_texts.update(transformation(*args, **kwargs))
return list(new_attacked_texts)
def __repr__(self):
main_str = "CompositeTransformation" + "("
transformation_lines = []
for i, transformation in enumerate(self.transformations):
transformation_lines.append(utils.add_indent(f"({i}): {transformation}", 2))
transformation_lines.append(")")
main_str += utils.add_indent("\n" + "\n".join(transformation_lines), 2)
return main_str
__str__ = __repr__
"""
===================
Augmenter Class
===================
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PreTransformationConstraint(ABC):
"""
An abstract class that represents constraints which are applied before
the transformation.
These restrict which words are allowed to be modified during the
transformation. For example, we might not allow stopwords to be
modified.
"""
def __call__(self, current_text, transformation):
"""
Returns the word indices in current_text which are able to be
modified. First checks compatibility with transformation then calls
_get_modifiable_indices
Args:
current_text: String
The AttackedText Object input to consider.
transformation: Transformation Object
The Transformation which will be applied.
Returns:
Modifiable indices of input if transformation is compatible
Words of current text otherwise
"""
if not self.check_compatibility(transformation):
return set(range(len(current_text.words)))
return self._get_modifiable_indices(current_text)
@abstractmethod
def _get_modifiable_indices(current_text):
"""
Returns the word indices in current_text which are able to be
modified. Must be overridden by specific pre-transformation
constraints.
Args:
current_text: AttackedText
The AttackedText input to consider.
Returns:
Set of word indices that may be modified (implemented by subclasses)
"""
raise NotImplementedError()
def check_compatibility(self, transformation):
"""
Checks if this constraint is compatible with the given
transformation. For example, the WordEmbeddingDistance constraint
compares the embedding of the word inserted with that of the word
deleted. Therefore it can only be applied in the case of word swaps,
and not for transformations which involve only one of insertion or
deletion.
Args:
transformation: Transformation Object
The Transformation to check compatibility for.
Returns:
True by default; subclasses override this method to restrict compatibility
"""
return True
def extra_repr_keys(self):
"""
Set the extra representation of the constraint using these keys.
To print customized extra information, you should reimplement
this method in your own constraint. Both single-line and multi-
line strings are acceptable.
Args:
None
Returns:
[]
"""
return []
__str__ = __repr__ = default_class_repr
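# Illustrative sketch (not used below): a minimal PreTransformationConstraint
# subclass that only allows words of at least three characters to be modified.
#   class MinWordLengthConstraint(PreTransformationConstraint):
#       def _get_modifiable_indices(self, current_text):
#           return {i for i, word in enumerate(current_text.words) if len(word) >= 3}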
flair.device = device
def words_from_text(s, words_to_ignore=[]):
"""
Splits a string into words, keeping alphanumeric characters and known
homoglyphs and dropping other punctuation.
Args:
s: String
Input String
words_to_ignore: List
List of words that explicitly need to be ignored
Returns:
words: List
Extracted words, excluding any listed in words_to_ignore
"""
homos = set(
[
"˗",
"৭",
"Ȣ",
"𝟕",
"б",
"Ƽ",
"Ꮞ",
"Ʒ",
"ᒿ",
"l",
"O",
"`",
"ɑ",
"Ь",
"ϲ",
"ԁ",
"е",
"𝚏",
"ɡ",
"հ",
"і",
"ϳ",
"𝒌",
"ⅼ",
"m",
"ո",
"о",
"р",
"ԛ",
"ⲅ",
"ѕ",
"𝚝",
"ս",
"ѵ",
"ԝ",
"×",
"у",
"ᴢ",
]
)
words = []
word = ""
for c in " ".join(s.split()):
if c.isalnum() or c in homos:
word += c
elif c in "'-_*@" and len(word) > 0:
# Allow apostrophes, hyphens, underscores, asterisks and at signs as long as they don't begin the
# word.
word += c
elif word:
if word not in words_to_ignore:
words.append(word)
word = ""
if len(word) and (word not in words_to_ignore):
words.append(word)
return words
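# Example (illustrative): words_from_text("Don't panic -- it's fine!") returns
# ["Don't", "panic", "it's", "fine"]; punctuation is dropped, but apostrophes,
# hyphens, underscores, asterisks and at-signs are kept when inside a word.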
_flair_pos_tagger = None
def flair_tag(sentence, tag_type="upos-fast"):
"""
Tags a Sentence object using flair part-of-speech tagger.
Args:
sentence: Object
Input Sequence
tag_type: String
Type of flair tag that needs to be applied
Returns:
Nothing
"""
global _flair_pos_tagger
if not _flair_pos_tagger:
from flair.models import SequenceTagger
_flair_pos_tagger = SequenceTagger.load(tag_type)
_flair_pos_tagger.predict(sentence)
def zip_flair_result(pred, tag_type="upos-fast"):
"""
Takes a sentence tagging from flair and returns two lists, of words
and their corresponding parts-of-speech.
Args:
pred: Object
Resulting Prediction on input sentence post tagging
tag_type: String
Type of flair tag that needs to be applied
Returns:
word_list, tag_list: List, List
Words of the tagged sentence and their corresponding tags
"""
from flair.data import Sentence
# NOTE: body reconstructed from the TextAttack utility this cell adapts;
# flair's tag-access API may differ slightly across versions.
word_list, tag_list = [], []
for token in pred.tokens:
word_list.append(token.text)
if "pos" in tag_type:
tag_list.append(token.annotation_layers["pos"][0]._value)
elif tag_type == "ner":
tag_list.append(token.get_label("ner"))
return word_list, tag_list
class AttackedText:
"""
A helper class that represents a string that can be attacked.
Models that take multiple sentences as input separate them by SPLIT_TOKEN.
Attacks "see" the entire input, joined into one string, without the split
token.
AttackedText instances that were perturbed from other AttackedText
objects contain a pointer to the previous text
(attack_attrs["previous_attacked_text"]), so that the full chain of
perturbations might be reconstructed by using this key to form a linked
list.
"""
SPLIT_TOKEN = "<SPLIT>"
def __init__(self, text_input, attack_attrs=None):
# Read in ``text_input`` as a string or OrderedDict.
"""
Initiates the following attributes:
Args:
text_input: String or OrderedDict
The string (or column-to-string mapping) that this AttackedText represents
attack_attrs: Dictionary
Dictionary of various attributes stored during the
course of an attack.
Returns:
Nothing
"""
if isinstance(text_input, str):
self._text_input = OrderedDict([("text", text_input)])
elif isinstance(text_input, OrderedDict):
self._text_input = text_input
else:
raise TypeError(
f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
)
# Process input lazily.
self._words = None
self._words_per_input = None
self._pos_tags = None
self._ner_tags = None
# Format text inputs.
self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
if attack_attrs is None:
self.attack_attrs = dict()
elif isinstance(attack_attrs, dict):
self.attack_attrs = attack_attrs
else:
raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
# Indices of words from the *original* text. Allows us to map
# indices between original text and this text, and vice-versa.
self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words))
# A list of all indices in *this* text that have been modified.
self.attack_attrs.setdefault("modified_indices", set())
def __eq__(self, other):
"""
Compares two AttackedText instances, checking both their text and their
attack attributes.
Since some elements stored in self.attack_attrs may be numpy
arrays, we have to take special care when comparing them.
Args:
other: AttackedText
The second text instance to compare against
Returns:
True if the texts and attack attributes match, False otherwise
"""
if not (self.text == other.text):
return False
if len(self.attack_attrs) != len(other.attack_attrs):
return False
for key in self.attack_attrs:
if key not in other.attack_attrs:
return False
elif isinstance(self.attack_attrs[key], np.ndarray):
if not (self.attack_attrs[key].shape == other.attack_attrs[key].shape):
return False
elif not (self.attack_attrs[key] == other.attack_attrs[key]).all():
return False
else:
if not self.attack_attrs[key] == other.attack_attrs[key]:
return False
return True
def __hash__(self):
return hash(self.text)
def free_memory(self):
"""
Delete items that take up memory.
Can be called once the AttackedText is only needed to display.
Args:
None
Returns:
Nothing
"""
if "previous_attacked_text" in self.attack_attrs:
self.attack_attrs["previous_attacked_text"].free_memory()
self.attack_attrs.pop("previous_attacked_text", None)
self.attack_attrs.pop("last_transformation", None)
# Iterate over a copy of the keys since entries may be removed below.
for key in list(self.attack_attrs.keys()):
if isinstance(self.attack_attrs[key], torch.Tensor):
self.attack_attrs.pop(key, None)
def text_window_around_index(self, index, window_size):
"""
The text window of window_size words centered around
index.
Args:
index: Integer
Index of transformation within input sequence
window_size: Integer
Specifies size of the window around index
Returns:
Substring of text with specified window_size
"""
length = self.num_words
half_size = (window_size - 1) / 2.0
if index - half_size < 0:
start = 0
end = min(window_size - 1, length - 1)
elif index + half_size >= length:
start = max(0, length - window_size)
end = length - 1
else:
start = index - math.ceil(half_size)
end = index + math.floor(half_size)
text_idx_start = self._text_index_of_word_index(start)
text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
return self.text[text_idx_start:text_idx_end]
def pos_of_word_index(self, desired_word_idx):
"""
Returns the part-of-speech of the word at index word_idx.
Uses FLAIR part-of-speech tagger.
Args:
desired_word_idx: Integer
Index where POS transformation is to be applied within input sequence
Returns:
Part-of-speech of the word at index word_idx
"""
if not self._pos_tags:
sentence = Sentence(
self.text, use_tokenizer=words_from_text
)
flair_tag(sentence)
self._pos_tags = sentence
flair_word_list, flair_pos_list = zip_flair_result(
self._pos_tags
)
for word_idx, word in enumerate(self.words):
assert (
word in flair_word_list
), "word absent in flair returned part-of-speech tags"
word_idx_in_flair_tags = flair_word_list.index(word)
if word_idx == desired_word_idx:
return flair_pos_list[word_idx_in_flair_tags]
else:
flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
flair_pos_list = flair_pos_list[word_idx_in_flair_tags + 1 :]
raise ValueError(
f"Did not find word from index {desired_word_idx} in flair POS tag"
)
def ner_of_word_index(self, desired_word_idx, model_name="ner"):
"""
Returns the ner tag of the word at index word_idx.
Uses FLAIR ner tagger.
Args:
desired_word_idx: Integer
Index of the word whose NER tag is required
model_name: String
Name of the model tag that needs to be applied
Returns:
ner tag of the word at index word_idx.
"""
if not self._ner_tags:
sentence = Sentence(
self.text, use_tokenizer = words_from_text
)
flair_tag(sentence, model_name)
self._ner_tags = sentence
flair_word_list, flair_ner_list = zip_flair_result(
self._ner_tags, "ner"
)
for word_idx, word in enumerate(flair_word_list):
word_idx_in_flair_tags = flair_word_list.index(word)
if word_idx == desired_word_idx:
return flair_ner_list[word_idx_in_flair_tags]
else:
flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
flair_ner_list = flair_ner_list[word_idx_in_flair_tags + 1 :]
raise ValueError(
f"Did not find word from index {desired_word_idx} in flair NER tags"
)
def _text_index_of_word_index(self, i):
"""
Returns the character index in self.text at which word i begins.
Args:
i: Integer
Index of the word of interest.
Returns:
look_after_index: Integer
Character offset of the start of word i within self.text
"""
pre_words = self.words[: i + 1]
lower_text = self.text.lower()
# Find all words until `i` in string.
look_after_index = 0
for word in pre_words:
look_after_index = lower_text.find(word.lower(), look_after_index) + len(
word
)
look_after_index -= len(self.words[i])
return look_after_index
def text_until_word_index(self, i):
"""
Returns the text before the beginning of word at index i.
Args:
i: Integer
Index of word upon which perturbation is intended.
Returns:
Text before the beginning of word at index i.
"""
look_after_index = self._text_index_of_word_index(i)
return self.text[:look_after_index]
def text_after_word_index(self, i):
"""
Returns the text after the end of word at index i.
Args:
i: Integer
Index of word upon which perturbation is intended.
Returns:
Text after the end of word at index i.
"""
# Get index of beginning of word then jump to end of word.
look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
return self.text[look_after_index:]
def first_word_diff(self, other_attacked_text):
"""
Returns the first word in self.words that differs from
other_attacked_text.
Useful for word swap strategies.
Args:
other_attacked_text: String Object
Sentence/sequence to be compared with given input
Returns:
w1: String
First differing word in self.words if difference exists
None otherwise
"""
w1 = self.words
w2 = other_attacked_text.words
for i in range(min(len(w1), len(w2))):
if w1[i] != w2[i]:
return w1[i]
return None
def first_word_diff_index(self, other_attacked_text):
"""
Returns the index of the first word in self.words that differs from
other_attacked_text.
Useful for word swap strategies.
Args:
other_attacked_text: String object
Sentence/sequence to be compared with given input
Returns:
i: Integer
Index of the first differing word if a difference exists
None otherwise
"""
w1 = self.words
w2 = other_attacked_text.words
for i in range(min(len(w1), len(w2))):
if w1[i] != w2[i]:
return i
return None
def all_words_diff(self, other_attacked_text):
"""
Returns the set of indices for which this and other_attacked_text
have different words.
Args:
other_attacked_text: String object
Sentence/sequence to be compared with given input
Returns:
indices: Set
Indices at which self.words and other_attacked_text.words differ
"""
indices = set()
w1 = self.words
w2 = other_attacked_text.words
for i in range(min(len(w1), len(w2))):
if w1[i] != w2[i]:
indices.add(i)
return indices
def ith_word_diff(self, other_attacked_text, i):
"""
Returns whether the word at index i differs from
other_attacked_text.
Args:
other_attacked_text: String object
Sentence/sequence to be compared with given input
i: Integer
Index of word of interest within input sequence
Returns:
Boolean
True if the words at index i differ (or if i is out of range for either text)
"""
w1 = self.words
w2 = other_attacked_text.words
if len(w1) - 1 < i or len(w2) - 1 < i:
return True
return w1[i] != w2[i]
def words_diff_num(self, other_attacked_text):
# using edit distance to calculate words diff num
def generate_tokens(words):
"""
Generates token for given sequence of words
Args:
words: List
Sequence of words
Returns:
result: Dictionary
Word mapped to corresponding index
"""
result = {}
idx = 1
for w in words:
if w not in result:
result[w] = idx
idx += 1
return result
def words_to_tokens(words, tokens):
"""
Helper function to extract corresponding words from tokens
Args:
words: List
Sequence of words
tokens: List
Sequence of tokens
Returns:
result: List
Corresponding token for each word
"""
result = []
for w in words:
result.append(tokens[w])
return result
def edit_distance(w1_t, w2_t):
"""
Finds the edit distance between two token sequences.
Args:
w1_t: List
Token sequence #1
w2_t: List
Token sequence #2
Returns:
Integer edit distance between the two token sequences
"""
matrix = [
[i + j for j in range(len(w2_t) + 1)] for i in range(len(w1_t) + 1)
]
for i in range(1, len(w1_t) + 1):
for j in range(1, len(w2_t) + 1):
if w1_t[i - 1] == w2_t[j - 1]:
d = 0
else:
d = 1
matrix[i][j] = min(
matrix[i - 1][j] + 1,
matrix[i][j - 1] + 1,
matrix[i - 1][j - 1] + d,
)
return matrix[len(w1_t)][len(w2_t)]
def cal_dif(w1, w2):
"""
Calculates the word-level edit distance between two word sequences.
Args:
w1: List
Word sequence #1
w2: List
Word sequence #2
Returns:
Edit distance between the two sequences after mapping each word to a token
"""
tokens = generate_tokens(w1 + w2)
w1_t = words_to_tokens(w1, tokens)
w2_t = words_to_tokens(w2, tokens)
return edit_distance(w1_t, w2_t)
w1 = self.words
w2 = other_attacked_text.words
return cal_dif(w1, w2)
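# Example (illustrative): words_diff_num computes a word-level edit distance, e.g.
# comparing ["a", "great", "meal"] against ["a", "meal"] gives 1 (one word deleted).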
def convert_from_original_idxs(self, idxs):
"""
Takes indices of words from original string and converts them to
indices of the same words in the current string.
Uses information from
self.attack_attrs['original_index_map'], which maps word
indices from the original to perturbed text.
Args:
idxs: List
List of indexes
Returns:
List of mapping of word indices from the original to perturbed text
"""
if len(self.attack_attrs["original_index_map"]) == 0:
return idxs
elif isinstance(idxs, set):
idxs = list(idxs)
elif not isinstance(idxs, (list, np.ndarray)):
raise TypeError(
f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
)
return [self.attack_attrs["original_index_map"][i] for i in idxs]
def replace_words_at_indices(self, indices, new_words):
"""
This code returns a new AttackedText object where the word at
index is replaced with a new word.
Args:
indices: List
List of indexes of words in input sequence
new_words: List
List of words with new word as replacement for original word
Returns:
New AttackedText object where the word at
index is replaced with a new word.
"""
if len(indices) != len(new_words):
raise ValueError(
f"Cannot replace {len(new_words)} words at {len(indices)} indices."
)
words = self.words[:]
for i, new_word in zip(indices, new_words):
if not isinstance(new_word, str):
raise TypeError(
f"replace_words_at_indices requires ``str`` words, got {type(new_word)}"
)
if (i < 0) or (i >= len(words)):
raise ValueError(f"Cannot assign word at index {i}")
words[i] = new_word
return self.generate_new_attacked_text(words)
def replace_word_at_index(self, index, new_word):
"""
This code returns a new AttackedText object where the word at
index is replaced with a new word.
Args:
index: Integer
Index of word
new_word: String
New word for replacement at index of word
Returns:
New AttackedText object where the word at
index is replaced with a new word.
"""
if not isinstance(new_word, str):
raise TypeError(
f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
)
return self.replace_words_at_indices([index], [new_word])
def delete_word_at_index(self, index):
"""
This code returns a new AttackedText object where the word at
index is removed.
Args:
index: Integer
Index of word
Returns:
New AttackedText object where the word at
index is removed.
"""
return self.replace_word_at_index(index, "")
def insert_text_after_word_index(self, index, text):
"""
Inserts a string after the word at index "index" and attempts to add
appropriate spacing.
Args:
index: Integer
Index of word
text: String
Input Sequence
Returns:
New AttackedText object with the text inserted
after the word at index "index".
"""
if not isinstance(text, str):
raise TypeError(f"text must be an str, got type {type(text)}")
word_at_index = self.words[index]
new_text = " ".join((word_at_index, text))
return self.replace_word_at_index(index, new_text)
def insert_text_before_word_index(self, index, text):
"""
Inserts a string before word at index "index" and attempts to add
appropriate spacing.
Args:
index: Integer
Index of word
text: String
Input Sequence
Returns:
New AttackedText object with the text inserted
before the word at index "index".
"""
if not isinstance(text, str):
raise TypeError(f"text must be an str, got type {type(text)}")
word_at_index = self.words[index]
# TODO if ``word_at_index`` is at the beginning of a sentence, we should
# optionally capitalize ``text``.
new_text = " ".join((text, word_at_index))
return self.replace_word_at_index(index, new_text)
def get_deletion_indices(self):
"""
Returns the entries of original_index_map that mark deleted words.
Args:
None
Returns:
Entries of attack_attrs["original_index_map"] equal to -1, one per deleted word
"""
return self.attack_attrs["original_index_map"][
self.attack_attrs["original_index_map"] == -1
]
def generate_new_attacked_text(self, new_words):
"""
Returns a new AttackedText object and replaces old list of words
with a new list of words, but preserves the punctuation and spacing of
the original message.
self.words is a list of the words in the current text with
punctuation removed. However, each "word" in new_words could
be an empty string, representing a word deletion, or a string
with multiple space-separated words, representing an insertion
of one or more words.
Args:
new_words: String
New word for potential replacement
Returns:
New AttackedText object with perturbed text and attack attributes
"""
perturbed_text = ""
original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
new_attack_attrs = dict()
if "label_names" in self.attack_attrs:
new_attack_attrs["label_names"] = self.attack_attrs["label_names"]
new_attack_attrs["newly_modified_indices"] = set()
# Point to previously monitored text.
new_attack_attrs["previous_attacked_text"] = self
# Use `new_attack_attrs` to track indices with respect to the original
# text.
new_attack_attrs["modified_indices"] = self.attack_attrs[
"modified_indices"
].copy()
new_attack_attrs["original_index_map"] = self.attack_attrs[
"original_index_map"
].copy()
new_i = 0
# Create the new attacked text by swapping out words from the original
# text with a sequence of 0+ words in the new text.
for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
word_start = original_text.index(input_word)
word_end = word_start + len(input_word)
perturbed_text += original_text[:word_start]
original_text = original_text[word_end:]
adv_words = words_from_text(adv_word_seq)
adv_num_words = len(adv_words)
num_words_diff = adv_num_words - len(words_from_text(input_word))
# Track indices on insertions and deletions.
if num_words_diff != 0:
# Re-calculated modified indices. If words are inserted or deleted,
# they could change.
shifted_modified_indices = set()
for modified_idx in new_attack_attrs["modified_indices"]:
if modified_idx < i:
shifted_modified_indices.add(modified_idx)
elif modified_idx > i:
shifted_modified_indices.add(modified_idx + num_words_diff)
else:
pass
new_attack_attrs["modified_indices"] = shifted_modified_indices
# Track insertions and deletions wrt original text.
# original_modification_idx = i
new_idx_map = new_attack_attrs["original_index_map"].copy()
if num_words_diff == -1:
# Word deletion
new_idx_map[new_idx_map == i] = -1
new_idx_map[new_idx_map > i] += num_words_diff
if num_words_diff > 0 and input_word != adv_words[0]:
# If insertion happens before the `input_word`
new_idx_map[new_idx_map == i] += num_words_diff
new_attack_attrs["original_index_map"] = new_idx_map
# Move pointer and save indices of new modified words.
for j in range(i, i + adv_num_words):
if input_word != adv_word_seq:
new_attack_attrs["modified_indices"].add(new_i)
new_attack_attrs["newly_modified_indices"].add(new_i)
new_i += 1
# Check spaces for deleted text.
if adv_num_words == 0 and len(original_text):
# Remove extra space (or else there would be two spaces for each
# deleted word).
# @TODO What to do with punctuation in this case? This behavior is undefined.
if i == 0:
# If the first word was deleted, take a subsequent space.
if original_text[0] == " ":
original_text = original_text[1:]
else:
# If a word other than the first was deleted, take a preceding space.
if perturbed_text[-1] == " ":
perturbed_text = perturbed_text[:-1]
# Add substitute word(s) to new sentence.
perturbed_text += adv_word_seq
perturbed_text += original_text # Add all of the ending punctuation.
# Reform perturbed_text into an OrderedDict.
perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
perturbed_input = OrderedDict(
zip(self._text_input.keys(), perturbed_input_texts)
)
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
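# Example (illustrative): AttackedText("I love this place!").replace_word_at_index(1, "like")
# produces a new AttackedText whose .text is "I like this place!"; the original
# punctuation and spacing are preserved and attack_attrs["modified_indices"] becomes {1}.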
def words_diff_ratio(self, x):
"""
Get the ratio of word differences between current text and x.
Note that current text and x must have same number of words.
Args:
x: String
Compares x with input text for ratio of word differences
Returns:
Ratio of word differences between current text and x.
"""
assert self.num_words == x.num_words
return float(np.sum(np.array(self.words) != np.array(x.words))) / self.num_words
def align_with_model_tokens(self, model_wrapper):
"""
Align AttackedText's words with target model's tokenization scheme
(e.g. word, character, subword).
Specifically, we map each word to list
of indices of tokens that compose the
word (e.g. embedding --> ["em","##bed", "##ding"])
Args:
model_wrapper: textattack.models.wrappers.ModelWrapper
ModelWrapper of the target model
Returns:
word2token_mapping: (dict[int, list[int]])
Dictionary that maps i-th word to list of indices.
"""
tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
word2token_mapping = {}
j = 0
last_matched = 0
for i, word in enumerate(self.words):
matched_tokens = []
while j < len(tokens) and len(word) > 0:
token = tokens[j].lower()
idx = word.lower().find(token)
if idx == 0:
word = word[idx + len(token) :]
matched_tokens.append(j)
last_matched = j
j += 1
if not matched_tokens:
word2token_mapping[i] = None
j = last_matched
else:
word2token_mapping[i] = matched_tokens
return word2token_mapping
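# Example (illustrative, assuming a subword tokenizer with prefixes stripped): for
# words ["embedding", "matrix"] and tokens ["em", "bed", "ding", "matrix"], the
# returned mapping would be {0: [0, 1, 2], 1: [3]}.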
@property
def tokenizer_input(self):
"""
The tuple of inputs to be passed to the tokenizer.
"""
input_tuple = tuple(self._text_input.values())
# Prefer to return a string instead of a tuple with a single value.
if len(input_tuple) == 1:
return input_tuple[0]
else:
return input_tuple
@property
def column_labels(self):
"""
Returns the labels for this text's columns.
For single-sequence inputs, this simply returns ['text'].
"""
return list(self._text_input.keys())
@property
def words_per_input(self):
"""
Returns a list of lists of words corresponding to each input.
"""
if not self._words_per_input:
self._words_per_input = [
words_from_text(_input) for _input in self._text_input.values()
]
return self._words_per_input
@property
def words(self):
if not self._words:
self._words = words_from_text(self.text)
return self._words
@property
def text(self):
"""
Represents full text input.
Multiple inputs are joined with a line break.
"""
return "\n".join(self._text_input.values())
@property
def num_words(self):
"""
Returns the number of words in the sequence.
"""
return len(self.words)
def printable_text(self, key_color="bold", key_color_method=None):
"""
Represents full text input. Adds field descriptions.
Args:
key_color: String
Field description of input text
key_color_method: String
Color method description of input text
Usage/Example:
entailment inputs look like:
premise: ...
hypothesis: ...
Returns:
The raw text for single-sequence inputs;
one "key: value" line per field for multi-sequence inputs
"""
# For single-sequence inputs, don't show a prefix.
if len(self._text_input) == 1:
return next(iter(self._text_input.values()))
# For multiple-sequence inputs, show a prefix and a colon. Optionally,
# color the key.
else:
if key_color_method:
def ck(k):
return textattack.shared.utils.color_text(
k, key_color, key_color_method
)
else:
def ck(k):
return k
return "\n".join(
f"{ck(key.capitalize())}: {value}"
for key, value in self._text_input.items()
)
def __repr__(self):
return f'<AttackedText "{self.text}">'
class Augmenter:
"""
A class for performing data augmentation using TextAttack.
"""
def __init__(
self,
transformation,
constraints=[],
pct_words_to_swap=0.1,
transformations_per_example=1,
):
"""
Initiates the following attributes:
Args:
transformation: Transformation Object
The transformation that suggests new texts from an input.
constraints: List
Constraints that each transformation must meet
pct_words_to_swap: Float [0., 1.],
Percentage of words to swap per augmented example
transformations_per_example: Integer
Maximum number of augmentations per input
Returns:
None
"""
assert (
transformations_per_example > 0
), "transformations_per_example must be a positive integer"
assert 0.0 <= pct_words_to_swap <= 1.0, "pct_words_to_swap must be in [0., 1.]"
self.transformation = transformation
self.pct_words_to_swap = pct_words_to_swap
self.transformations_per_example = transformations_per_example
self.constraints = []
self.pre_transformation_constraints = []
for constraint in constraints:
if isinstance(constraint, PreTransformationConstraint):
self.pre_transformation_constraints.append(constraint)
else:
self.constraints.append(constraint)
def _filter_transformations(self, transformed_texts, current_text, original_text):
"""
Filters a list of AttackedText objects to include only the ones
that pass self.constraints.
Args:
transformed_texts: List
List of candidate transformed AttackedText objects
current_text: AttackedText
The text the transformations were generated from
original_text: AttackedText
The original input text, used by constraints that compare against the original
Returns:
The subset of transformed_texts that satisfies every constraint in self.constraints
"""
for C in self.constraints:
if len(transformed_texts) == 0:
break
if C.compare_against_original:
if not original_text:
raise ValueError(
f"Missing `original_text` argument when constraint {type(C)} is set to compare against "
f"`original_text` "
)
transformed_texts = C.call_many(transformed_texts, original_text)
else:
transformed_texts = C.call_many(transformed_texts, current_text)
return transformed_texts
def augment(self, text):
"""
Returns all possible augmentations of text according to
self.transformation.
Args:
text: String
Text to be augmented via transformation
Returns:
Sorted list of all possible augmentations of text according to
compatible self.transformation.
"""
attacked_text = AttackedText(text)
original_text = attacked_text
all_transformed_texts = set()
num_words_to_swap = max(
int(self.pct_words_to_swap * len(attacked_text.words)), 1
)
for _ in range(self.transformations_per_example):
current_text = attacked_text
words_swapped = len(current_text.attack_attrs["modified_indices"])
while words_swapped < num_words_to_swap:
transformed_texts = self.transformation(
current_text, self.pre_transformation_constraints
)
# Get rid of transformations we already have
transformed_texts = [
t for t in transformed_texts if t not in all_transformed_texts
]
# Filter out transformations that don't match the constraints.
transformed_texts = self._filter_transformations(
transformed_texts, current_text, original_text
)
# If there are no transformed texts left after filtering, terminate
if not len(transformed_texts):
break
current_text = random.choice(transformed_texts)
# update words_swapped based on modified indices
words_swapped = max(
len(current_text.attack_attrs["modified_indices"]),
words_swapped + 1,
)
all_transformed_texts.add(current_text)
return sorted([at.printable_text() for at in all_transformed_texts])
def augment_many(self, text_list, show_progress=False):
"""
Returns all possible augmentations of a list of strings according to
self.transformation.
Args:
text_list: List of strings
A list of strings for data augmentation
show_progress: Boolean
A variable that controls visibility of Augmentation progress
Returns:
A list of augmented-text lists, one per input string.
"""
if show_progress:
text_list = tqdm.tqdm(text_list, desc="Augmenting data...")
return [self.augment(text) for text in text_list]
def augment_text_with_ids(self, text_list, id_list, show_progress=True):
"""
Supplements a list of text with more text data.
Args:
text_list: List of strings
A list of strings for data augmentation
id_list: List of indexes
A list of indexes for corresponding strings
show_progress: Boolean
A variable that controls visibility of augmentation progress
Returns:
all_text_list, all_id_list: List, List
The augmented text along with the corresponding IDs for
each augmented example.
"""
if len(text_list) != len(id_list):
raise ValueError("List of text must be same length as list of IDs")
if self.transformations_per_example == 0:
return text_list, id_list
all_text_list = []
all_id_list = []
if show_progress:
text_list = tqdm.tqdm(text_list, desc="Augmenting data...")
for text, _id in zip(text_list, id_list):
augmented_texts = self.augment(text)
# Keep the original text alongside its augmentations, with matching IDs.
all_text_list.extend([text] + augmented_texts)
all_id_list.extend([_id] * (1 + len(augmented_texts)))
return all_text_list, all_id_list
def __repr__(self):
main_str = "Augmenter" + "("
lines = []
# self.transformation
lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
# self.constraints
constraints_lines = []
constraints = self.constraints + self.pre_transformation_constraints
if len(constraints):
for i, constraint in enumerate(constraints):
constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
else:
constraints_str = "None"
lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
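# Usage sketch (illustrative only, not executed): the Bonus cells below wire these
# pieces together roughly like this, using the WordSwap* transformations defined above:
#   transformation = CompositeTransformation([WordSwapContract(), WordSwapQWERTY()])
#   augmenter = Augmenter(transformation=transformation, transformations_per_example=1)
#   augmenter.augment("The staff was friendly and the food was great.")
# augment() returns a sorted list of perturbed copies of the input text.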
Bonus 3.2: Augment the original review¶
# @title Bonus 3.2: Augment the original review
# @markdown ---
# @markdown Word-level Augmentations
word_swap_contract = True # @param {type:"boolean"}
word_swap_extend = False # @param {type:"boolean"}
word_swap_homoglyph_swap = False # @param {type:"boolean"}
# @markdown ---
# @markdown Character-level Augmentations
word_swap_neighboring_character_swap = True # @param {type:"boolean"}
word_swap_qwerty = False # @param {type:"boolean"}
word_swap_random_character_deletion = False # @param {type:"boolean"}
word_swap_random_character_insertion = False # @param {type:"boolean"}
word_swap_random_character_substitution = False # @param {type:"boolean"}
# @markdown ---
# @markdown Check all the augmentations that you wish to apply!
# @markdown **NOTE:** *Try applying each augmentation individually, and observe the changes.*
# Apply augmentations
augmentations = []
if word_swap_contract:
augmentations.append(WordSwapContract())
if word_swap_extend:
augmentations.append(WordSwapExtend())
if word_swap_homoglyph_swap:
augmentations.append(WordSwapHomoglyphSwap())
if word_swap_neighboring_character_swap:
augmentations.append(WordSwapNeighboringCharacterSwap())
if word_swap_qwerty:
augmentations.append(WordSwapQWERTY())
if word_swap_random_character_deletion:
augmentations.append(WordSwapRandomCharacterDeletion())
if word_swap_random_character_insertion:
augmentations.append(WordSwapRandomCharacterInsertion())
if word_swap_random_character_substitution:
augmentations.append(WordSwapRandomCharacterSubstitution())
transformation = CompositeTransformation(augmentations)
augmenter = Augmenter(transformation=transformation,
transformations_per_example=1)
augmented_review = clean_text(augmenter.augment(context)[0])
print("Augmented review:\n")
pprint(augmented_review)
Augmented review:
('I discovered this restaurant when I was living in Montrael. It serves South '
"Indian and Sri Laknna food. It remains my favourite restaurant ever. I'm "
'living in Ottawa now, but every time I visit Montreal, I hvae to go to '
'Jolee. Teh decor is notihng fancy, but the food is so delicious. There '
'is a take uot counter at the back of the restaurant and people are '
'constantly coming in for take uot. The Beef Rolls aer amazing. Spicy '
'beef and potato, rolled up and then deep fried. Yummy goodness! The Fish '
'Cutlets are good too (fish and potato). yM favourite dish is the Chicken '
'Kottu Roti. Teh portion is huge and has lots of chicken, egg, oinons and '
'roti pieces. The Beef Biriyani is excellent too (lots of beef, cashews, and '
'a obiled egg). I find that hte dishes here are spicy. I like psicy, but I '
'can only take a "meidum spicy " here. I usually just order mild, which has '
'lpenty of bite. Teh prices aer so awesome. The huge Chicken Kottu tRoi '
'si like only $7 and can easily feed two people. The Beef Rolls are less '
'than $2 a piece (and they are biggre than usual egg rolls). Their '
"desserts are by the weight and very reasonable ni cost. I usually don't "
'like Indian desserts because I find them too sweet, but I love the desserts '
'at Jolee. I usually order a box to go and get a pieec fo everything. I '
"have to admit that I don't know the names of any of hte desserts or what is "
"actually in them, but I know that they're colourful, pretty and delicious "
":) Writing about this place is making me miss it so much. If you haven't "
"tried it yet, you should go, you won't regret it!")
We can now check the model's predictions for the original review and its augmented version. Try to find a combination of perturbations that breaks the model, i.e., makes it give an incorrect prediction for the augmented text.
Bonus 3.3: Check model predictions¶
# @title Bonus 3.3: Check model predictions
def getPrediction(text):
"""
Outputs model prediction based on the input text.
Args:
text: String
Input text
Returns:
pred.item(): Integer
Predicted class index for the input text
"""
inputs = tokenizer(text, padding="max_length",
truncation=True, return_tensors="pt")
for key, value in inputs.items():
inputs[key] = value.to(model.device)
outputs = model(**inputs)
logits = outputs.logits
pred = torch.argmax(logits, dim=1)
return pred.item()
print("original Review:\n")
pprint(context)
print("\nPredicted Sentiment =", getPrediction(context))
print("########################################")
print("\nAugmented Review:\n")
pprint(augmented_review)
print("\nPredicted Sentiment =", getPrediction(augmented_review))
print("########################################")
original Review:
('I discovered this restaurant when I was living in Montreal. It serves South '
"Indian and Sri Lankan food. It remains my favourite restaurant ever. I'm "
'living in Ottawa now, but every time I visit Montreal, I have to go to '
'Jolee. The decor is nothing fancy, but the food is so delicious. There '
'is a take out counter at the back of the restaurant and people are '
'constantly coming in for take out. The Beef Rolls are amazing. Spicy '
'beef and potato, rolled up and then deep fried. Yummy goodness! The Fish '
'Cutlets are good too (fish and potato). My favourite dish is the Chicken '
'Kottu Roti. The portion is huge and has lots of chicken, egg, onions and '
'roti pieces. The Beef Biriyani is excellent too (lots of beef, cashews, and '
'a boiled egg). I find that the dishes here are spicy. I like spicy, but I '
'can only take a "medium spicy " here. I usually just order mild, which has '
'plenty of bite. The prices are so awesome. The huge Chicken Kottu Roti '
'is like only $7 and can easily feed two people. The Beef Rolls are less '
'than $2 a piece (and they are bigger than usual egg rolls). Their '
"desserts are by the weight and very reasonable in cost. I usually don't "
'like Indian desserts because I find them too sweet, but I love the desserts '
'at Jolee. I usually order a box to go and get a piece of everything. I '
"have to admit that I don't know the names of any of the desserts or what is "
'actually in them, but I know that they are colourful, pretty and delicious '
":) Writing about this place is making me miss it so much. If you haven't "
"tried it yet, you should go, you won't regret it!")
Predicted Sentiment = 4
########################################
Augmented Review:
('I discovered this restaurant when I was living in Montrael. It serves South '
"Indian and Sri Laknna food. It remains my favourite restaurant ever. I'm "
'living in Ottawa now, but every time I visit Montreal, I hvae to go to '
'Jolee. Teh decor is notihng fancy, but the food is so delicious. There '
'is a take uot counter at the back of the restaurant and people are '
'constantly coming in for take uot. The Beef Rolls aer amazing. Spicy '
'beef and potato, rolled up and then deep fried. Yummy goodness! The Fish '
'Cutlets are good too (fish and potato). yM favourite dish is the Chicken '
'Kottu Roti. Teh portion is huge and has lots of chicken, egg, oinons and '
'roti pieces. The Beef Biriyani is excellent too (lots of beef, cashews, and '
'a obiled egg). I find that hte dishes here are spicy. I like psicy, but I '
'can only take a "meidum spicy " here. I usually just order mild, which has '
'lpenty of bite. Teh prices aer so awesome. The huge Chicken Kottu tRoi '
'si like only $7 and can easily feed two people. The Beef Rolls are less '
'than $2 a piece (and they are biggre than usual egg rolls). Their '
"desserts are by the weight and very reasonable ni cost. I usually don't "
'like Indian desserts because I find them too sweet, but I love the desserts '
'at Jolee. I usually order a box to go and get a pieec fo everything. I '
"have to admit that I don't know the names of any of hte desserts or what is "
"actually in them, but I know that they're colourful, pretty and delicious "
":) Writing about this place is making me miss it so much. If you haven't "
"tried it yet, you should go, you won't regret it!")
Predicted Sentiment = 1
########################################
Submit your feedback¶
# @title Submit your feedback
content_review(f"{feedback_prefix}_Textattack_module_Interactive_Demos")