From e15f3e81579da67dfceb21f54da61eddcde36575 Mon Sep 17 00:00:00 2001 From: hannandarryl <hannandarryl@gmail.com> Date: Wed, 19 Jan 2022 16:35:54 +0000 Subject: [PATCH] mobile push server side --- mobile_app/train_mobile.py | 6 +++--- sparse_coding_torch/mobile_model.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mobile_app/train_mobile.py b/mobile_app/train_mobile.py index 59dc628..487ed79 100644 --- a/mobile_app/train_mobile.py +++ b/mobile_app/train_mobile.py @@ -54,8 +54,8 @@ if __name__ == "__main__": mobile_model = NetTensorFlowWrapper(args.sparse_checkpoint, None) - for param in mobile_model.sparse_model.parameters(): - param.requires_grad = False +# for param in mobile_model.sparse_model.parameters(): +# param.requires_grad = False mobile_model.to(device) @@ -107,7 +107,7 @@ if __name__ == "__main__": criterion = torch.nn.BCEWithLogitsLoss() if args.train: - prediction_optimizer = torch.optim.Adam(mobile_model.parameters(), + prediction_optimizer = torch.optim.Adam(mobile_model.predictive_model.parameters(), lr=args.lr) for epoch in range(args.epochs): diff --git a/sparse_coding_torch/mobile_model.py b/sparse_coding_torch/mobile_model.py index 9b6e18d..09476b1 100644 --- a/sparse_coding_torch/mobile_model.py +++ b/sparse_coding_torch/mobile_model.py @@ -411,7 +411,7 @@ class NetTensorFlowWrapper(nn.Module): rectifier=True, lam=0.05, max_activation_iter=200, - activation_lr=1e-1) + activation_lr=1e-2) self.predictive_model = SmallDataClassifierConv3d() -- GitLab