import torch
from torch import nn
from superduper_torch.model import TorchModel
from superduper_torch.training import TorchTrainer
from torch.nn.functional import cross_entropy
class SimpleModel(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Produces unnormalized per-class scores (no softmax is applied), which is
    the form ``torch.nn.functional.cross_entropy`` expects.

    Args:
        input_size: Number of input features (size of the last input dim).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size=16, hidden_size=32, num_classes=3):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores for ``x``; the last dim of ``x`` must be ``input_size``."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
def preprocess(x):
    """Convert one raw input sample into a tensor for the model.

    ``torch.tensor`` infers the dtype from ``x``. NOTE(review): integer
    features would produce an integer tensor, which likely fails against the
    float weights of ``nn.Linear`` — confirm inputs are floats or cast here.
    """
    return torch.tensor(x)
def postprocess(x):
    """Return the position of the highest score in ``x`` as a plain ``int``.

    Only the single best entry is kept via ``topk(1)``; ``.item()`` then
    requires that exactly one index results, so ``x`` should hold the scores
    of a single sample.
    """
    best = x.topk(1)
    return int(best.indices.item())
def data_transform(features, label):
    """Prepare one training pair: tensorize ``features`` and pass ``label`` through unchanged."""
    features_tensor = torch.tensor(features)
    return features_tensor, label
# Wrap a freshly initialised SimpleModel in a superduper TorchModel with a
# TorchTrainer attached. NOTE(review): feature_size, num_classes, input_key
# and select are not defined in the visible code; they must exist beforehand.
model = TorchModel(
    identifier='my-model',
    # Constructed inline so ``model`` only ever names the TorchModel wrapper.
    object=SimpleModel(feature_size, num_classes=num_classes),
    preprocess=preprocess,
    postprocess=postprocess,
    trainer=TorchTrainer(
        key=(input_key, 'label'),
        identifier='my_trainer',
        objective=cross_entropy,
        loader_kwargs={'batch_size': 10},
        max_iterations=1000,
        validation_interval=100,
        select=select,
        transform=data_transform,
    ),
)