Transfer learning
# Toggle: False builds/exports the template only; True also inserts the
# sample data and applies components to the database.
APPLY = False
# Target collection: a template placeholder normally, the concrete sample
# table when APPLY is on.
COLLECTION_NAME = '<var:table_name>' if not APPLY else 'sample_transfer_learning'
# Which branch of the ModelRouter below to use: 'text' or 'image'.
MODALITY = 'text'
from superduper import superduper, CFG
# In-memory MongoDB stand-in, used for local testing/export.
db = superduper('mongomock://test_db')
Get useful sample data
def getter(modality='text'):
    """Download and return the sample dataset for *modality*.

    Args:
        modality: 'text' -> records loaded from text_classification.json,
            truncated to 200 rows; any other value -> image records with a
            PIL image under 'x' and a label under 'y'.

    Returns:
        list: the sample records.
    """
    # Fix: ``json`` was imported twice in the original; import once here.
    import json
    import os
    import subprocess

    if modality == 'text':
        subprocess.run([
            'curl', '-O',
            'https://superduperdb-public-demo.s3.amazonaws.com/text_classification.json',
        ])
        with open("text_classification.json", "r") as f:
            data = json.load(f)
        # Fix: remove the temp file directly instead of shelling out to rm.
        os.remove('text_classification.json')
        data = data[:200]
    else:
        subprocess.run([
            'curl', '-O',
            'https://superduperdb-public-demo.s3.amazonaws.com/images_classification.zip',
        ])
        subprocess.run(['unzip', 'images_classification.zip'])
        os.remove('images_classification.zip')
        from PIL import Image
        with open('images/images.json', 'r') as f:
            data = json.load(f)
        data = [{'x': Image.open(d['image_path']), 'y': d['label']} for d in data]
    return data
After obtaining the data, we insert it into the database.
Insert simple data
With auto_schema enabled, we can insert data directly: superduper automatically infers the data types and creates the matching table and datatype definitions.
if APPLY:
    from superduper import Document

    # Wrap each raw record in a Document and insert into the collection.
    documents = [Document(r) for r in data]
    ids = db[COLLECTION_NAME].insert(documents).execute()
Compute features
import sentence_transformers
from superduper import vector, Listener
from superduper_sentence_transformers import SentenceTransformer
# Text feature extractor: MiniLM sentence embeddings, converted to plain
# Python lists so the vectors can be stored in the database.
superdupermodel_text = SentenceTransformer(
identifier="embedding",
model='all-MiniLM-L6-v2',
postprocess=lambda x: x.tolist(),
)
import torchvision
from torchvision import transforms
from superduper_torch import TorchModel
from superduper import Listener, imported
from PIL import Image
class TorchVisionEmbedding:
    """Holder for a pre-trained ResNet-18 kept in eval mode.

    NOTE(review): this class appears unused below — the served image model
    is built via ``imported(torchvision.models.resnet18)`` instead.
    """

    def __init__(self):
        # Fix: the original referenced bare ``models``, which is never
        # imported anywhere in this file; qualify via ``torchvision``.
        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.resnet.eval()
def preprocess(image):
    """Convert a PIL image into a normalized ImageNet input tensor.

    Pipeline: resize to 256 -> center-crop 224 -> tensor -> per-channel
    normalization with the standard ImageNet mean/std.
    """
    # Fix: the original had a duplicated assignment
    # (``preprocess = preprocess = ...``) that also shadowed this
    # function's own name; use a distinct local name.
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(image)
# Image feature extractor: a torchvision ResNet-18, wrapped with
# ``imported`` so the template can re-create it on deserialization.
resnet = imported(torchvision.models.resnet18)(pretrained=True)
superdupermodel_image = TorchModel(
identifier='my-vision-model',
object=resnet,
# ``preprocess`` refers to the module-level image preprocess defined above.
preprocess=preprocess,
postprocess=lambda x: x.numpy().tolist()
)
from superduper.components.model import ModelRouter
# Route between the text and image extractors; MODALITY selects the branch.
feature_extractor = ModelRouter(
'feature_extractor',
models={
'text': superdupermodel_text,
'image': superdupermodel_image,
},
model=MODALITY,
)
# Listener: computes features for field 'x' of every record in the
# collection and stores them as model outputs.
feature_extractor_listener = Listener(
model=feature_extractor,
select=db[COLLECTION_NAME].select(),
key='x',
identifier="features"
)
if APPLY:
    # Applying replaces the local object with the database-backed one.
    feature_extractor_listener = db.apply(
        feature_extractor_listener,
        force=True,
    )
Build and train classifier
from superduper_sklearn import Estimator, SklearnTrainer
from sklearn.svm import SVC
# Scikit-learn branch: an SVC trained on the listener's feature outputs,
# with 'y' as the label field.
scikit_model = Estimator(
identifier="my-model-scikit",
object=SVC(),
trainer=SklearnTrainer(
"my-scikit-trainer",
key=(feature_extractor_listener.outputs, "y"),
select=db[COLLECTION_NAME].outputs(feature_extractor_listener.predict_id),
),
)
import torch
from torch import nn
from superduper_torch.model import TorchModel
from superduper_torch.training import TorchTrainer
from torch.nn.functional import cross_entropy
class SimpleModel(nn.Module):
    """Two-layer MLP classifier over extracted feature vectors.

    The first layer is sized lazily from the incoming feature dimension,
    so the same class serves both the text and image embeddings.

    Args:
        input_size: kept for backward compatibility; the input width is
            actually inferred on the first forward pass.
        hidden_size: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=16, hidden_size=32, num_classes=2):
        super().__init__()
        self.hidden_size = hidden_size
        # Fix: the original created ``fc1`` inside ``forward`` on the first
        # call, so its parameters were invisible to any optimizer built
        # beforehand and could end up on the wrong device. LazyLinear
        # registers the module up front and infers in_features lazily.
        self.fc1 = nn.LazyLinear(hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return unnormalized class logits for a (batch, features) input."""
        out = self.fc1(x)
        out = self.relu(out)
        return self.fc2(out)
# Fix: named lambda (PEP 8 E731) replaced with a plain ``def``.
# NOTE(review): this rebinds ``preprocess``, shadowing the image
# preprocess defined earlier — the torch model below uses this one;
# confirm that is intended.
def preprocess(x):
    """Convert a stored feature list into a torch tensor."""
    return torch.tensor(x)


def postprocess(x):
    """Map a logit vector to the index of the highest-scoring class."""
    return int(x.topk(1)[1].item())


def data_transform(features, label):
    """Pair a feature tensor with its (unchanged) label for training."""
    return torch.tensor(features), label
model = SimpleModel(num_classes=2)
# Torch branch: wraps SimpleModel with pre/post-processing and a trainer
# that reads the feature listener's outputs with 'y' as the label.
torch_model = TorchModel(
identifier='my-model-torch',
object=model,
preprocess=preprocess,
postprocess=postprocess,
trainer=TorchTrainer(
key=(feature_extractor_listener.outputs, 'y'),
identifier='my-torch-trainer',
objective=cross_entropy,
loader_kwargs={'batch_size': 10},
max_iterations=1000,
validation_interval=100,
select=db[COLLECTION_NAME].outputs(feature_extractor_listener.predict_id),
transform=data_transform,
),
)
Define a validation for evaluating the effect after training.
from superduper import Dataset, Metric, Validation
def acc(x, y):
    """Return the fraction of positions where *x* and *y* agree.

    Robustness fix: returns 0.0 for empty input instead of raising
    ZeroDivisionError.
    """
    if not x:
        return 0.0
    return sum(xx == yy for xx, yy in zip(x, y)) / len(x)
# Wrap ``acc`` as a Metric and evaluate on the 'valid' fold of the
# feature-listener outputs.
accuracy = Metric(identifier="acc", object=acc)
validation = Validation(
"transfer_learning_performance",
key=(feature_extractor_listener.outputs, "y"),
datasets=[
Dataset(
identifier="my-valid",
select=db[COLLECTION_NAME].outputs(feature_extractor_listener.predict_id).add_fold('valid')
)
],
metrics=[accuracy],
)
# Attach the same validation suite to both candidate models so either
# branch of the router reports comparable metrics after training.
scikit_model.validation = validation
torch_model.validation = validation
If we execute the apply function, the model is added to the database; because the model carries a Trainer, training runs automatically as part of apply.
# Route between the two training frameworks; default branch is sklearn.
estimator = ModelRouter(
'estimator',
models={
'scikit-framework': scikit_model,
'torch-framework': torch_model,
},
# NOTE(review): 'framework' is listed in template_variables below but has
# no entry in ``substitutions`` — confirm the template machinery rewrites
# this literal as intended.
model='scikit-framework',
)
if APPLY:
    # Applying the estimator triggers training, since both models carry
    # trainers.
    db.apply(estimator, force=True)
Get the training metrics
if APPLY:
    # Reload the trained sklearn branch and inspect its recorded metrics.
    model = db.load('model', 'my-model-scikit')
    model.metric_values
from superduper import Application
# Bundle the whole pipeline (feature listener + classifier) into one app.
application = Application(
identifier='transfer-learning',
components=[feature_extractor_listener, estimator],
)
from superduper import Template, Table, Schema
from superduper.components.dataset import RemoteData
# Export the application as a reusable template. The ``substitutions``
# mapping turns concrete strings into template variables: occurrences of
# 'docs' become <var:table_name> and 'text' becomes <var:modality>.
t = Template(
'transfer_learning',
# Default table created when the template is applied without user data.
default_table=Table(
'sample_transfer_learning',
schema=Schema(
'sample_transfer_learning/schema',
fields={'x': 'str', 'y': 'int'},
),
# Sample data fetched on demand through ``getter`` defined above.
data=RemoteData(
'text_classification',
getter=getter,
),
),
template=application,
substitutions={'docs': 'table_name', 'text': 'modality'},
# NOTE(review): 'framework' has no entry in ``substitutions`` — confirm
# it is substituted elsewhere or is intended as a default-only variable.
template_variables=['table_name', 'framework', 'modality'],
types={
'table_name': {
'type': 'str',
'default': 'sample_transfer_learning',
},
'modality': {
'type': 'str',
'default': 'text',
},
'framework': {
'type': 'str',
'default': 'scikit-framework',
},
}
)
t.export('.')