CIFAR-10

CIFAR-10 데이터셋으로 간단한 뉴럴 네트워크를 훈련하고 추론하는 코드입니다. CIFAR-10 Dataset Link : https://www.cs.toronto.edu/~kriz/cifar.html

class hnvlib.cifar10.CIFARNetwork

학습과 추론에 사용되는 간단한 뉴럴 네트워크입니다.

forward(x: torch.Tensor) torch.Tensor

피드 포워드(순전파)를 진행하는 함수입니다.

매개변수

x (Tensor) – 입력 이미지

반환

입력 이미지에 대한 예측값 (클래스값)

반환 형식

Tensor

class hnvlib.cifar10.CIFARNetworkModule

모델과 학습/추론 코드가 포함된 파이토치 라이트닝 모듈입니다.

configure_optimizers() torch.optim.optimizer.Optimizer

옵티마이저를 정의합니다.

반환

파이토치 옵티마이저

반환 형식

torch.optim.Optimizer

forward(x: torch.Tensor) torch.Tensor

피드 포워딩 함수

매개변수

x (Tensor) – 입력 이미지

반환

입력 이미지에 대한 예측값

반환 형식

Tensor

training_step(batch: torch.Tensor, batch_idx: int) Dict[str, float]

뉴럴 네트워크를 한 스텝 훈련합니다.

매개변수
  • batch (int) – 훈련 데이터셋의 배치 크기

  • batch_idx (int) – 배치에 대한 인덱스

반환

훈련 오차 데이터

반환 형식

Dict[str, float]

validation_epoch_end(outputs: List[torch.Tensor]) None

한 에폭 검증을 마치고 실행되는 코드입니다.

매개변수

outputs (List[Tensor]) – 함수 validation_step에서 반환한 값들을 한 에폭이 끝나는 동안 모은 값들의 집합

validation_step(batch: torch.Tensor, batch_idx: int) Dict[str, float]

훈련 후 한 배치를 검증합니다.

매개변수
  • batch (int) – 검증 데이터셋의 배치 크기

  • batch_idx (int) – 배치에 대한 인덱스

반환

검증 오차 데이터

반환 형식

Dict[str, float]

hnvlib.cifar10.predict(test_data: torch.utils.data.dataset.Dataset, model: torch.nn.modules.module.Module) None

학습한 뉴럴 네트워크로 CIFAR-10 데이터셋을 분류합니다.

매개변수
  • test_data (Dataset) – 추론에 사용되는 데이터셋

  • model (nn.Module) – 추론에 사용되는 모델

hnvlib.cifar10.run_pytorch(batch_size: int, epochs: int) None

학습/추론 파이토치 파이프라인입니다.

매개변수
  • batch_size (int) – 학습 및 추론 데이터셋의 배치 크기

  • epochs (int) – 전체 학습 데이터셋을 훈련하는 횟수

hnvlib.cifar10.run_pytorch_lightning(batch_size: int, epochs: int) None

학습/추론 파이토치 라이트닝 파이프라인입니다.

매개변수
  • batch_size (int) – 학습 및 추론 데이터셋의 배치 크기

  • epochs (int) – 전체 학습 데이터셋을 훈련하는 횟수

hnvlib.cifar10.test(dataloader: torch.utils.data.dataloader.DataLoader, device: hnvlib.cifar10._device, model: torch.nn.modules.module.Module, loss_fn: torch.nn.modules.module.Module) None

CIFAR-10 데이터셋으로 뉴럴 네트워크의 성능을 테스트합니다.

매개변수
  • dataloader (DataLoader) – 파이토치 데이터로더

  • device (_device) – 테스트에 사용되는 장치

  • model (nn.Module) – 테스트에 사용되는 모델

  • loss_fn (nn.Module) – 테스트에 사용되는 오차 함수

hnvlib.cifar10.train(dataloader: torch.utils.data.dataloader.DataLoader, device: hnvlib.cifar10._device, model: torch.nn.modules.module.Module, loss_fn: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer) None

CIFAR-10 데이터셋으로 뉴럴 네트워크를 훈련합니다.

매개변수
  • dataloader (DataLoader) – 파이토치 데이터로더

  • device (_device) – 훈련에 사용되는 장치

  • model (nn.Module) – 훈련에 사용되는 모델

  • loss_fn (nn.Module) – 훈련에 사용되는 오차 함수

  • optimizer (torch.optim.Optimizer) – 훈련에 사용되는 옵티마이저