Skip to content
Snippets Groups Projects
Commit e15f3e81 authored by hannandarryl's avatar hannandarryl
Browse files

mobile push server side

parent ae4847be
Branches
Tags
No related merge requests found
...@@ -54,8 +54,8 @@ if __name__ == "__main__": ...@@ -54,8 +54,8 @@ if __name__ == "__main__":
mobile_model = NetTensorFlowWrapper(args.sparse_checkpoint, None) mobile_model = NetTensorFlowWrapper(args.sparse_checkpoint, None)
for param in mobile_model.sparse_model.parameters(): # for param in mobile_model.sparse_model.parameters():
param.requires_grad = False # param.requires_grad = False
mobile_model.to(device) mobile_model.to(device)
...@@ -107,7 +107,7 @@ if __name__ == "__main__": ...@@ -107,7 +107,7 @@ if __name__ == "__main__":
criterion = torch.nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
if args.train: if args.train:
prediction_optimizer = torch.optim.Adam(mobile_model.parameters(), prediction_optimizer = torch.optim.Adam(mobile_model.predictive_model.parameters(),
lr=args.lr) lr=args.lr)
for epoch in range(args.epochs): for epoch in range(args.epochs):
......
...@@ -411,7 +411,7 @@ class NetTensorFlowWrapper(nn.Module): ...@@ -411,7 +411,7 @@ class NetTensorFlowWrapper(nn.Module):
rectifier=True, rectifier=True,
lam=0.05, lam=0.05,
max_activation_iter=200, max_activation_iter=200,
activation_lr=1e-1) activation_lr=1e-2)
self.predictive_model = SmallDataClassifierConv3d() self.predictive_model = SmallDataClassifierConv3d()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment