Sentiment Analysis Fine-Tuning with BERT
By Yangming Li
Introduction
In this article, we explore the process of fine-tuning a BERT model for sentiment analysis using the PEFT library and the LoRA technique. We will walk through the practical steps and provide a detailed code example.
Understanding PEFT and LoRA
1. PEFT Library
The PEFT (Parameter-Efficient Fine-Tuning) library is designed for adapting large pre-trained models by training only a small subset of parameters. It makes fine-tuning possible with far less memory and computation than updating every weight of the model.
2. LoRA Technique
LoRA (Low-Rank Adaptation) is a fine-tuning method that freezes the pre-trained weights and injects trainable low-rank matrices into selected layers. The weight update is factored as the product of two small matrices, so the model can adapt to a new task while training only a tiny fraction of its original parameters, and the update can be merged back into the base weights so inference cost does not increase.
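To make the idea concrete, here is a minimal PyTorch sketch of the low-rank update applied to a single linear projection. The dimensions, rank, and scaling factor are illustrative assumptions, not the PEFT library's internals:

import torch
import torch.nn as nn

d_in, d_out, r = 768, 768, 8                    # r is the LoRA rank, with r << d_in
W = torch.randn(d_out, d_in)                    # frozen pre-trained weight (not trained)
A = nn.Parameter(torch.randn(r, d_in) * 0.01)   # trainable down-projection
B = nn.Parameter(torch.zeros(d_out, r))         # trainable up-projection, zero-initialized
alpha = 16                                      # scaling hyperparameter

x = torch.randn(4, d_in)                        # a batch of hidden states
# Adapted forward pass: the original projection plus the scaled low-rank correction
h = x @ W.T + (alpha / r) * (x @ A.T) @ B.T

Because B starts at zero, the adapted model is initially identical to the pre-trained one, and only A and B receive gradients during fine-tuning.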
Practical Steps for Implementation
Data Preparation
First, prepare a dataset for text sentiment classification. It should contain text content and corresponding sentiment labels (e.g., positive, negative). Divide the dataset into training, validation, and test sets.
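As a sketch, the split can be done with scikit-learn's train_test_split; the texts, labels, and split ratios below are placeholders rather than the article's dataset:

from sklearn.model_selection import train_test_split

texts = ["I love this product!", "Terrible support.", "Works as expected.", "Would not buy again."]
labels = [1, 0, 1, 0]  # 1: positive, 0: negative

# Hold out 20% for testing, then split the remainder into training and validation sets
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, test_size=0.25, random_state=42)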
Model Loading
Load the pre-trained BERT model and tokenizer using the Hugging Face Transformers library. This serves as the foundation for our fine-tuning process.
Define LoRA Fine-Tuning Layers
Define the LoRA configuration with the PEFT library and wrap the pre-trained model with it. This injects trainable low-rank matrices into selected layers (typically the attention query and value projections) while the original BERT weights stay frozen; a short sketch follows.
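A minimal sketch of this step with the PEFT API; the rank, scaling, dropout, and target module names are illustrative choices rather than values prescribed here:

from transformers import BertForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

base_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,          # sequence classification task
    r=8,                                 # rank of the low-rank matrices
    lora_alpha=16,                       # scaling factor
    lora_dropout=0.1,
    target_modules=["query", "value"],   # attention projections to adapt
)

lora_model = get_peft_model(base_model, lora_config)
lora_model.print_trainable_parameters()  # reports how few parameters are actually trained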
Configure Training Parameters
Set hyperparameters such as the learning rate, batch size, and number of epochs. Because only the LoRA matrices (and the classification head) are trainable, the optimizer only has to update a small fraction of the model's weights.
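A brief sketch, assuming lora_model is the PEFT-wrapped model from the previous step (the hyperparameter values are illustrative):

import torch

learning_rate = 1e-5
epochs = 3

# Frozen base weights have requires_grad=False, so only the LoRA matrices and
# the classification head are handed to the optimizer.
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)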
Training and Evaluation
Train the model on the training set with mini-batch gradient descent and evaluate its performance on the validation set. The full script below puts all of these steps together.
Code Example
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertForSequenceClassification, BertTokenizer
from peft import LoraConfig, TaskType, get_peft_model
# Custom dataset class for text sentiment classification
class SentimentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        # Drop the extra batch dimension added by return_tensors='pt'
        inputs = {key: val.squeeze(0) for key, val in inputs.items()}
        return inputs, label
# Example data (replace with your actual dataset)
train_texts = ["I love this product!", "This is the worst service I've ever received."]
train_labels = [1, 0] # 1: positive, 0: negative
val_texts = ["Amazing experience.", "Not happy with the quality."]
val_labels = [1, 0]
# Load pre-trained model and tokenizer
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Define the LoRA configuration and wrap the base model with it
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,          # sequence classification
    r=8,                                 # rank of the low-rank update matrices
    lora_alpha=16,                       # scaling factor
    lora_dropout=0.1,
    target_modules=["query", "value"],   # apply LoRA to the attention query/value projections
)
lora_model = get_peft_model(model, lora_config)
# Prepare datasets and dataloaders
max_length = 128
train_dataset = SentimentDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = SentimentDataset(val_texts, val_labels, tokenizer, max_length)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Configure training parameters
learning_rate = 1e-5
epochs = 3
# Define the optimizer (the model computes the cross-entropy loss internally when labels are passed)
optimizer = torch.optim.Adam(lora_model.parameters(), lr=learning_rate)
# Training loop
for epoch in range(epochs):
    lora_model.train()
    for batch_inputs, batch_labels in train_loader:
        optimizer.zero_grad()
        outputs = lora_model(**batch_inputs, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{epochs} - Training loss: {loss.item()}")
# Evaluation loop
lora_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_inputs, batch_labels in val_loader:
        outputs = lora_model(**batch_inputs)
        _, predicted = torch.max(outputs.logits, 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()
accuracy = correct / total
print(f"Validation Accuracy: {accuracy * 100:.2f}%")