diff --git a/mobile_app/train_mobile.py b/mobile_app/train_mobile.py index 59dc628fe34252e6ee872dbac28d3cd70a3fef10..487ed79f07e57c8dc32483963446747ef366c80e 100644 --- a/mobile_app/train_mobile.py +++ b/mobile_app/train_mobile.py @@ -54,8 +54,8 @@ if __name__ == "__main__": mobile_model = NetTensorFlowWrapper(args.sparse_checkpoint, None) - for param in mobile_model.sparse_model.parameters(): - param.requires_grad = False +# for param in mobile_model.sparse_model.parameters(): +# param.requires_grad = False mobile_model.to(device) @@ -107,7 +107,7 @@ if __name__ == "__main__": criterion = torch.nn.BCEWithLogitsLoss() if args.train: - prediction_optimizer = torch.optim.Adam(mobile_model.parameters(), + prediction_optimizer = torch.optim.Adam(mobile_model.predictive_model.parameters(), lr=args.lr) for epoch in range(args.epochs): diff --git a/sparse_coding_torch/mobile_model.py b/sparse_coding_torch/mobile_model.py index 9b6e18d8b9a0ecba986364ce562795d11dc3611d..09476b1d2e72847dd053f6a3e120df66c2656f69 100644 --- a/sparse_coding_torch/mobile_model.py +++ b/sparse_coding_torch/mobile_model.py @@ -411,7 +411,7 @@ class NetTensorFlowWrapper(nn.Module): rectifier=True, lam=0.05, max_activation_iter=200, - activation_lr=1e-1) + activation_lr=1e-2) self.predictive_model = SmallDataClassifierConv3d()