GLOBAL WHEAT DETECTION

Wheat 데이터셋으로 간단한 뉴럴 네트워크를 훈련하고 추론하는 코드입니다. Wheat Dataset Link : https://www.kaggle.com/c/global-wheat-detection

class hnvlib.global_wheat_detection.WheatDataset(image_dir: os.PathLike, csv_path: os.PathLike, transform: Optional[Sequence[Callable]] = None)

Wheat 데이터셋 사용자 정의 클래스를 정의합니다.

class hnvlib.global_wheat_detection.WheatDetectionModule(csv_path, lr)
configure_optimizers()

_summary_

Returns:

_type_: _description_

forward(x)

_summary_

Args:

x (_type_): _description_

Returns:

_type_: _description_

training_step(batch, batch_idx)

_summary_

Args:

batch (_type_): _description_ batch_idx (_type_): _description_

Returns:

_type_: _description_

validation_epoch_end(outputs)

_summary_

Args:

outputs (_type_): _description_

validation_step(batch, batch_idx)

_summary_

Args:

batch (_type_): _description_ batch_idx (_type_): _description_

Returns:

_type_: _description_

hnvlib.global_wheat_detection.run_pytorch(csv_path: os.PathLike, train_image_dir: os.PathLike, train_csv_path: os.PathLike, test_csv_path: os.PathLike, batch_size: int, epochs: int, lr: float) None

학습/추론 파이토치 파이프라인입니다.

매개변수
  • batch_size (int) – 학습 및 추론 데이터셋의 배치 크기

  • epochs (int) – 전체 학습 데이터셋을 훈련하는 횟수

hnvlib.global_wheat_detection.run_pytorch_lightning(csv_path: os.PathLike, train_image_dir: os.PathLike, train_csv_path: os.PathLike, test_csv_path: os.PathLike, batch_size: int, epochs: int, lr: float) None

_summary_

Args:

csv_path (os.PathLike): _description_ train_image_dir (os.PathLike): _description_ train_csv_path (os.PathLike): _description_ test_csv_path (os.PathLike): _description_ batch_size (int): _description_ epochs (int): _description_ lr (float): _description_

hnvlib.global_wheat_detection.split_dataset(csv_path: os.PathLike, split_rate: float = 0.2) None

Dirty-MNIST 데이터셋을 비율에 맞춰 train / test로 나눕니다.

매개변수
  • path (os.PathLike) – Dirty-MNIST 데이터셋 경로

  • split_rate (float) – train과 test로 데이터 나누는 비율

hnvlib.global_wheat_detection.test(dataloader: torch.utils.data.dataloader.DataLoader, device: hnvlib.global_wheat_detection._device, model: torch.nn.modules.module.Module, metric) None

CIFAR-10 데이터셋으로 뉴럴 네트워크의 성능을 테스트합니다.

매개변수
  • dataloader (DataLoader) – 파이토치 데이터로더

  • device (_device) – 테스트에 사용되는 장치

  • model (nn.Module) – 테스트에 사용되는 모델

  • loss_fn (nn.Module) – 테스트에 사용되는 오차 함수

hnvlib.global_wheat_detection.train(dataloader: torch.utils.data.dataloader.DataLoader, device: str, model: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer) None

Wheat 데이터셋으로 뉴럴 네트워크를 훈련합니다.

매개변수
  • dataloader (DataLoader) – 파이토치 데이터로더

  • device (str) – 훈련에 사용되는 장치

  • model (nn.Module) – 훈련에 사용되는 모델

  • optimizer (torch.optim.Optimizer) – 훈련에 사용되는 옵티마이저

hnvlib.global_wheat_detection.visualize_dataset(image_dir: os.PathLike, csv_path: os.PathLike, save_dir: os.PathLike, n_images: int = 10) None

데이터셋 샘플 bbox 그려서 시각화

매개변수

save_dir (os.PathLike) – bbox 그린 그림 저장할 폴더 경로

hnvlib.global_wheat_detection.visualize_predictions(testset: torch.utils.data.dataset.Dataset, device: str, model: torch.nn.modules.module.Module, save_dir: os.PathLike, conf_thr: float = 0.1, n_images: int = 10) None

이미지에 bbox 그려서 저장 및 시각화

매개변수
  • testset (Dataset) – 추론에 사용되는 데이터셋

  • device (str) – 추론에 사용되는 장치

  • model (nn.Module) – 추론에 사용되는 모델

  • save_dir (os.PathLike) – 추론한 사진이 저장되는 경로

  • conf_thr (float) – confidence threshold - 해당 숫자에 만족하지 않는 bounding box 걸러내는 파라미터