"""ResNet-50 classifiers with a replaceable classification head."""

import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


def get_resnet50(num_classes: int, *, weights=None) -> models.ResNet:
    """Build a torchvision ResNet-50 whose final layer outputs ``num_classes`` values.

    Args:
        num_classes: Number of output features of the new ``fc`` layer.
        weights: Forwarded unchanged to ``torchvision.models.resnet50``. The
            default ``None`` matches the original behaviour (no pretrained
            weights requested). Keyword-only so existing positional calls
            keep working.

    Returns:
        The ResNet-50 model with its ``fc`` layer replaced by a new
        ``nn.Linear(in_features, num_classes)``.
    """
    model = models.resnet50(weights=weights)
    # Replace only the head; the backbone's feature width is read from the
    # existing fc layer instead of being hard-coded.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


class ResNet50(nn.Module):
    """``nn.Module`` wrapper around a ResNet-50 built by ``get_resnet50``.

    The network is stored as ``self.model``, so parameter names in the state
    dict are prefixed with ``model.``.
    """

    def __init__(self, num_classes: int, *, weights=None) -> None:
        """Create the wrapped ResNet-50.

        Args:
            num_classes: Number of output features of the final ``fc`` layer.
            weights: Forwarded to ``torchvision.models.resnet50``; ``None``
                (the default) matches the original behaviour.
        """
        super().__init__()
        # Reuse the factory so the two construction paths cannot drift apart.
        self.model = get_resnet50(num_classes, weights=weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped ResNet-50 on ``x`` and return the ``fc`` layer's output.

        NOTE(review): presumably ``x`` is an image batch shaped (N, 3, H, W)
        as torchvision ResNets expect — confirm against callers.
        """
        return self.model(x)