본문 바로가기

Deep Learning

Continual Learning(연속학습)

일단 가장 기초적인 딥러닝 모델은 고정된 데이터셋에서 학습을 수행하고, 학습이 완료된 후에는 새로운 데이터나 작업에 대해 다시 처음부터 재학습해야 하는 경우가 많다. 하지만 실제 환경에서는 데이터가 시간에 따라 지속적으로 유입되거나, 새로운 class가 생성되거나 하는 경우가 있어, 모델이 기존 작업 외에도 새로운 작업을 학습해야 하는 경우가 많다. 이때 이전 지식을 유지하면서 새로운 지식을 습득하는 학습 방식continual learning 또는 lifelong learning이라고 한다. 

 

Continual Learning이 중요한 이유는...

기존 딥러닝 모델은 새로운 작업을 학습할 때 catastrophic forgetting(기존 지식의 급격한 손실) 문제를 겪는다. Continual learning은 이 문제를 해결하여, 모델이 과거에 학습했던 정보도 잘 기억하면서 새로운 지식도 효율적으로 습득할 수 있도록 하는데 그 목적이 있다. 

 

Continual Learning의 주요 접근법으로는 아래 3가지가 있다. 

 

  • Regularization-based methods
    • 중요한 파라미터의 변경을 억제하여 기존 지식을 보존
    • 대표: Elastic Weight Consolidation (EWC)
  • Replay-based methods
    • 과거 데이터를 일부 저장하거나 생성하여 새 데이터와 함께 학습
    • 대표: Experience Replay, Generative Replay
  • Parameter isolation methods
    • 작업마다 별도의 파라미터 공간을 사용해 간섭 방지
    • 대표: Progressive Neural Networks

 

아래의 간단한 예제를 보자. 

두 개의 간단한 작업을 순차적으로 학습해보는 continual learning 예제이다. 
Task 1: 0~4 숫자 구분
Task 2: 5~9 숫자 구분

 

먼저 catastrophic forgetting이 실제로 발생하는지 예제로 확인해보자. 

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# 단순 MLP 모델
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.fc(x)

# 데이터 준비
transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Task 1: 0~4 데이터
task1_train_idx = [i for i, (x, y) in enumerate(train_data) if y < 5]
task1_train_data = Subset(train_data, task1_train_idx)
task1_train_loader = DataLoader(task1_train_data, batch_size=64, shuffle=True)

task1_test_idx = [i for i, (x, y) in enumerate(test_data) if y < 5]
task1_test_data = Subset(test_data, task1_test_idx)
task1_test_loader = DataLoader(task1_test_data, batch_size=64, shuffle=False)

# Task 2: 5~9 데이터
task2_train_idx = [i for i, (x, y) in enumerate(train_data) if y >= 5]
task2_train_data = Subset(train_data, task2_train_idx)
task2_train_loader = DataLoader(task2_train_data, batch_size=64, shuffle=True)

task2_test_idx = [i for i, (x, y) in enumerate(test_data) if y >= 5]
task2_test_data = Subset(test_data, task2_test_idx)
task2_test_loader = DataLoader(task2_test_data, batch_size=64, shuffle=False)

# 모델, 손실함수, 옵티마이저
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 학습 함수
def train(model, loader):
    model.train()
    for epoch in range(1):  # epoch 수 늘려도 됨
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

# 테스트 함수
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    acc = 100. * correct / total
    return acc

# Task 1 학습
train(model, task1_train_loader)
task1_acc_after_task1 = test(model, task1_test_loader)
print(f'Task 1 학습 후 Task 1 정확도: {task1_acc_after_task1:.2f}%')

# Task 2 학습
train(model, task2_train_loader)
task1_acc_after_task2 = test(model, task1_test_loader)
task2_acc_after_task2 = test(model, task2_test_loader)
print(f'Task 2 학습 후 Task 1 정확도: {task1_acc_after_task2:.2f}%')
print(f'Task 2 학습 후 Task 2 정확도: {task2_acc_after_task2:.2f}%')

 

 

해당 코드를 실행해보면, 아래와 같은 결과가 나온다. 

Task2 학습 후, Task1의 성능이 급격히 떨어졌다.

 

이제 Continual Learning 기법 중 Replay 기법을 사용하여 위의 문제를 해결해보자. 

코드는 아래와 같다. 

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset, ConcatDataset
import random

# 단순 MLP 모델
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.fc(x)

# 데이터 준비
transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Task 1: 0~4 데이터
task1_train_idx = [i for i, (x, y) in enumerate(train_data) if y < 5]
task1_train_data = Subset(train_data, task1_train_idx)
task1_train_loader = DataLoader(task1_train_data, batch_size=64, shuffle=True)

task1_test_idx = [i for i, (x, y) in enumerate(test_data) if y < 5]
task1_test_data = Subset(test_data, task1_test_idx)
task1_test_loader = DataLoader(task1_test_data, batch_size=64, shuffle=False)

# Task 2: 5~9 데이터
task2_train_idx = [i for i, (x, y) in enumerate(train_data) if y >= 5]
task2_train_data = Subset(train_data, task2_train_idx)
task2_train_loader = DataLoader(task2_train_data, batch_size=64, shuffle=True)

task2_test_idx = [i for i, (x, y) in enumerate(test_data) if y >= 5]
task2_test_data = Subset(test_data, task2_test_idx)
task2_test_loader = DataLoader(task2_test_data, batch_size=64, shuffle=False)

# 모델, 손실함수, 옵티마이저
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 학습 함수
def train(model, loader):
    model.train()
    for epoch in range(1):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

# 테스트 함수
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    acc = 100. * correct / total
    return acc

# Replay Buffer (Task 1 데이터 중 일부)
def create_replay_buffer(loader, ratio=0.1):
    replay_x, replay_y = [], []
    for x, y in loader:
        replay_x.append(x)
        replay_y.append(y)
    replay_x = torch.cat(replay_x)
    replay_y = torch.cat(replay_y)

    n_replay = int(len(replay_x) * ratio)
    idx = random.sample(range(len(replay_x)), n_replay)

    replay_x = replay_x[idx]
    replay_y = replay_y[idx]

    replay_dataset = TensorDataset(replay_x, replay_y)
    return replay_dataset

# Task 1 학습
train(model, task1_train_loader)
task1_acc_after_task1 = test(model, task1_test_loader)
print(f'Task 1 학습 후 Task 1 정확도: {task1_acc_after_task1:.2f}%')

# Replay Buffer 생성
replay_dataset = create_replay_buffer(task1_train_loader, ratio=0.1)

# Task 2 + Replay 데이터셋 결합
combined_dataset = ConcatDataset([task2_train_data, replay_dataset])
combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)

# Task 2 학습 (Replay 포함)
train(model, combined_loader)

# 성능 측정
task1_acc_after_task2 = test(model, task1_test_loader)
task2_acc_after_task2 = test(model, task2_test_loader)
print(f'Task 2 학습 (Replay 포함) 후 Task 1 정확도: {task1_acc_after_task2:.2f}%')
print(f'Task 2 학습 (Replay 포함) 후 Task 2 정확도: {task2_acc_after_task2:.2f}%')

 

이렇게 코드를 수정하고 실행하면 아래와 같은 결과가 나온다. 

Task2 학습 이후에도, Task1의 정확도가 41% 정도 된다.

 

만족스러운 결과는 아니지만, 그래도 7.5%에서 40.8%로 많은 개선이 있다는 희망이 보인다. 

 

이처럼 처음 학습 성능을 가능한 유지하면서 새로운 task를 진행하는 것이 continual learning이다. 

 

그럼 이만~

'Deep Learning' 카테고리의 다른 글

Simple Imitation Learning  (0) 2025.05.15
Simple Object Detection with DETR  (1) 2025.05.13
Simple CNN 예제  (0) 2025.05.07
Multi-task learning  (0) 2025.04.30
강화학습(Reinforcement Learning) 최신 동향  (0) 2025.04.15