FASHION MNIST
FashionMNIST 데이터셋으로 간단한 뉴럴 네트워크를 훈련하고 추론하는 코드입니다. FashionMNIST Dataset Link : https://www.kaggle.com/datasets/zalando-research/fashionmnist
- class hnvlib.fashion_mnist.FashionMNISTNetwork
FashionMNIST 데이터를 훈련할 모델을 정의합니다.
- forward(x: torch.Tensor) torch.Tensor
피드 포워드(순전파)를 진행하는 함수입니다.
- 매개변수
x (Tensor) – 입력 이미지
- 반환
입력 이미지에 대한 예측값
- 반환 형식
Tensor
- class hnvlib.fashion_mnist.FashionMNISTNetworkModule
- configure_optimizers() torch.optim.optimizer.Optimizer
옵티마이저를 정의합니다.
- 반환
파이토치 옵티마이저
- 반환 형식
torch.optim.Optimizer
- forward(x: torch.Tensor) torch.Tensor
피드 포워딩
- 매개변수
x (Tensor) – 입력 이미지
- 반환
입력 이미지에 대한 예측값
- 반환 형식
Tensor
- hnvlib.fashion_mnist.predict(test_data: torch.utils.data.dataset.Dataset, model: torch.nn.modules.module.Module) None
학습한 뉴럴 네트워크로 FashionMNIST 데이터셋을 분류합니다.
- 매개변수
test_data (Dataset) – 추론에 사용되는 데이터셋
model (nn.Module) – 추론에 사용되는 모델
- hnvlib.fashion_mnist.run_pytorch_lightning(batch_size: int, epochs: int) None
학습/추론 파이토치 라이트닝 파이프라인입니다.
- hnvlib.fashion_mnist.test(dataloader: torch.utils.data.dataloader.DataLoader, device: str, model: torch.nn.modules.module.Module, loss_fn: torch.nn.modules.module.Module) None
FashionMNIST 데이터셋으로 뉴럴 네트워크의 성능을 테스트합니다.
- 매개변수
dataloader (DataLoader) – 파이토치 데이터로더
device (str) – 훈련에 사용되는 장치
model (nn.Module) – 훈련에 사용되는 모델
loss_fn (nn.Module) – 훈련에 사용되는 오차 함수
- hnvlib.fashion_mnist.train(dataloader: torch.utils.data.dataloader.DataLoader, device: str, model: torch.nn.modules.module.Module, loss_fn: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer) None
FashionMNIST 데이터셋으로 뉴럴 네트워크를 훈련합니다.
- 매개변수
dataloader (DataLoader) – 파이토치 데이터로더
device (str) – 훈련에 사용되는 장치
model (nn.Module) – 훈련에 사용되는 모델
loss_fn (nn.Module) – 훈련에 사용되는 오차 함수
optimizer (torch.optim.Optimizer) – 훈련에 사용되는 옵티마이저