import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics


class LightningModel(L.LightningModule):
    """LightningModule wrapper that trains a classifier with SGD.

    The wrapped ``model`` is called on the batch features, and its output is
    treated as per-class logits for a cross-entropy loss. Separate
    multiclass accuracy metrics are tracked for the train, validation and
    test loops.

    Args:
        model: Network invoked by ``forward``; not stored in the saved
            hyperparameters.
        learning_rate: Learning rate passed to ``torch.optim.SGD``.
        cosine_t_max: ``T_max`` passed to ``CosineAnnealingLR``. The
            scheduler is stepped per batch, so this is counted in optimizer
            steps rather than epochs.
        mode: When equal to ``"lrfind"``, ``configure_optimizers`` returns
            the bare optimizer without a scheduler; any other value enables
            the cosine schedule.
        num_classes: Number of target classes used by the accuracy metrics.
            Defaults to 10, the value that was previously hard-coded.
    """

    def __init__(self, model, learning_rate, cosine_t_max, mode, num_classes=10):
        super().__init__()

        self.learning_rate = learning_rate
        self.cosine_t_max = cosine_t_max
        self.model = model
        self.mode = mode

        # torch.Tensor(1, 3, 32, 32) would return *uninitialized* memory
        # (possibly NaN/inf); use zeros so the example input is deterministic.
        self.example_input_array = torch.zeros(1, 3, 32, 32)

        self.save_hyperparameters(ignore=["model"])

        # One metric object per loop so their running state never mixes.
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute loss and predictions for one ``(features, labels)`` batch.

        Returns:
            A tuple ``(loss, true_labels, predicted_labels)`` where
            ``predicted_labels`` is the argmax of the logits over dim 1.
        """
        features, true_labels = batch
        logits = self(features)

        loss = F.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        """Log the training loss and epoch-level accuracy; return the loss."""
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc(predicted_labels, true_labels)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accuracy for one batch."""
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log the test accuracy for one batch (the loss is not logged)."""
        _, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Build the SGD optimizer and, outside ``"lrfind"`` mode, its scheduler.

        Returns:
            The optimizer alone when ``self.mode == "lrfind"``; otherwise a
            dict holding the optimizer and a ``CosineAnnealingLR`` scheduler
            configuration that is stepped once per batch.
        """
        opt = torch.optim.SGD(self.parameters(), lr=self.learning_rate)

        if self.mode == "lrfind":
            return opt

        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=self.cosine_t_max
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "train_loss",
                "interval": "step",  # "step" means per batch; default is "epoch"
                "frequency": 1,  # default
            },
        }