5. Dropout
■ 파이토치에서 드롭아웃은 torch.nn.functional에서 dropout 함수로 다음과 같이 사용할 수 있다.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28 * 28, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 10)
self.dropout_prob = 0.5 # 몇 퍼센트 드롭아웃할 것인지
def forward(self, x):
x = x.view(-1, 28*28)
x = self.fc1(x)
x = F.relu(x)
x = F.dropout(x, training = self.training, p = self.dropout_prob) # 드롭아웃 적용
x = self.fc2(x)
x = F.relu(x)
x = F.dropout(x, training = self.training, p = self.dropout_prob) # 드롭아웃 적용
x = self.fc3(x)
x = F.log_softmax(x, dim = 1)
return x
- dropout 함수의 training 인수에 training을 지정한 이유는 학습 상태(model.train( ))인지 검증 상태(model.eval( ))인지에 따라 드롭아웃을 적용하기 위해서이다.
- model.train( )으로 학습 과정을 명시하면 training = True, model.eval( )으로 모델 검증 단계임을 명시하면 training = False로 적용되는데, True이면 드롭아웃을 적용하고 False이면 적용하지 않는다.
- 드롭아웃은 학습 과정에서'만' 지정한 드롭아웃 비율만큼의 노드를 선택해서, 선택된 노드들만 가중치가 업데이트되지 않도록 조정하고,
- 평가 과정에서는 모든 노드를 이용해 output을 계산해야 한다. 그러므로 학습 상태와 검증 상태에 따라 다르게 적용해야 한다.
■ torch.nn을 사용할 경우 데이터의 차원에 따라 nn.Dropout1d( ), nn.Dropout2d( ), nn.Dropout3d( ) 함수를 사용하면 된다.
6. 가중치 초깃값
■ 가중치 초깃값 설정은 torch.nn.init을 이용하면 된다.
https://pytorch.org/docs/stable/nn.init.html
torch.nn.init — PyTorch 2.5 documentation
Shortcuts
pytorch.org
■ 예를 들어 위의 코드에서 가중치가 쓰이는 계층인 nn.Linear에 He 초깃값을 적용해 보자.
- nn.Linear는 output으로 계산되는 벡터의 차원 수의 역수 값에 대한 +/- 범위 내 균등 분포로 가중치 값을 샘플링한다.
■ 파이토치에서는 다음과 같이 가중치를 사용하는 계층만 선택하여 가중치 초깃값을 설정할 수 있다.
import torch.nn.init as init
## 가중치 초깃값 설정 메서드 weight_init 정의
def weight_init(m):
if isinstance(m, nn.Linear): # nn.Linear만
init.kaiming_uniform_(m.weight.data) # 가중치 값에 He 초깃값 적용
model = Net().to(DEVICE)
model.apply(weight_init) # weight_init 적용
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5)
criterion = nn.CrossEntropyLoss()
7. 배치 정규화
■ 파이토치에서 nn.BatchNorm( ) 함수를 이용해 배치 정규화 계층을 지정할 수 있다.
■ 주의할 점은 차원에 따라 적용되는 함수명이 다르기 때문에 상황에 맞춰 적합한 배치 정규화 함수를 사용해야 한다.
- 1 차원에는 nn.BatchNorm1d( ), 2 차원에는 nn.BatchNorm2d( ), 3 차원에는 nn.BatchNorm3d( )를 사용해야 한다.
- 위의 class Net같이 MLP 완전연결 계층은 입력 데이터를 1차원으로 받아 각 계층에서의 데이터는 1차원 크기의 벡터 값을 계산한다. 이런 경우에는 nn.BatchNorm1d( )를 이용해야 한다.
■ 배치 정규화 계층은 연구자의 선호도에 따라 활성화 함수 이전에 적용하거나 이후에 적용한다.
■ 예를 들어 위의 코드에서 활성화 함수 이전에 배치 정규화를 적용하면
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28 * 28, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 10)
self.dropout_prob = 0.5 #
self.batch_norm1 = nn.BatchNorm1d(512) # 배치 정규화 계층 1
self.batch_norm2 = nn.BatchNorm1d(256) # 배치 정규화 계층 2
def forward(self, x):
x = x.view(-1, 28*28)
x = self.fc1(x)
x = self.batch_norm1(x) # 배치 정규화 계층 1 적용
x = F.relu(x) # 활성화 함수
x = F.dropout(x, training = self.training, p = self.dropout_prob) # 드롭아웃 적용
x = self.fc2(x)
x = self.batch_norm2(x) # 배치 정규화 계층 2 적용
x = F.relu(x) # 활성화 함수
x = F.dropout(x, training = self.training, p = self.dropout_prob) # 드롭아웃 적용
x = self.fc3(x)
x = F.log_softmax(x, dim = 1)
return x
■ 위의 코드에서 첫 번째 배치 정규화 계층의 인수에 512, 두 번째 배치 정규화 계층의 인수에 256을 넣었는데, forward 함수의 내부에서 모델 구조를 보면 첫 번째 배치 정규화 계층은 fc1과 fc2 사이, 두 번째 배치 정규화 계층은 fc2와 fc3 사이이다.
■ 여기서 fc1의 출력과 fc2의 입력은 512 크기의 벡터 값, fc2의 출력과 fc3의 입력은 256 크기의 벡터 값이다. 즉, nn.BatchNorm1d( )에 넣는 값은 이전 계층의 출력 크기(= 다음 계층의 입력 크기)이다.
8. 학습률 스케줄러
■ 파이토치에서 학습률 스케줄러는 다음과 같이 옵티마이저를 먼저 정의한 다음, torch.optim의 lr_scheduler를 통해 스케줄러를 정의할 수 있다.
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
■ 학습률 스케줄러의 종류
- 1) torch.optim.lr_scheduler.LambdaLR: 람다(lambda) 함수를 이용해 그 결과를 학습률로 설정
- 2) torch.optim.lr_scheduler.StepLR: 단계(step)마다 학습률을 감마(gamma) 비율만큼 감소
- 3) torch.optim.lr_scheduler.MultiStepLR: StepLR과 비슷하지만 특정 단계가 아니라 지정된 에포크에만 감마 비율로 감소
- 4) torch.optim.lr_scheduler.ExponentialLR: 에포크마다 이전 학습률에 감마만큼 곱함
- 5) torch.optim.lr_scheduler.CosineAnnealingLR: 학습률을 코사인(cosine) 함수의 형태처럼 변화시켜 학습률일 커지기도 하고 작아지기도 함
- 6) torch.optim.lr_scheduler.ReduceLROnPlateau: 학습이 잘되는지 아닌지에 따라 동적으로 학습률 변화
- 이외에도 다양한 학습률 스케줄러가 있다.
https://pytorch.org/docs/stable/optim.html
torch.optim — PyTorch 2.5 documentation
Shortcuts
pytorch.org
'파이토치' 카테고리의 다른 글
파이토치 합성곱 신경망(CNN) (2) (2) | 2024.12.06 |
---|---|
파이토치 합성곱 신경망(CNN) (1) (2) | 2024.12.05 |
torch.nn (1) (0) | 2024.12.04 |