diff --git a/diagnosis_service/svc/prediction_listener.py b/diagnosis_service/svc/prediction_listener.py
index 76caeff78fc367fe1dca2da4ad4501dbeb1364b9..b7ac6633139bb6e8f71f6b1be873b1cfb84126b1 100644
--- a/diagnosis_service/svc/prediction_listener.py
+++ b/diagnosis_service/svc/prediction_listener.py
@@ -1,8 +1,10 @@
-import contextlib
-import json
+
import os
-import random
+import json
import time
+import random
+import contextlib
+import traceback
from itertools import islice
@@ -48,50 +50,7 @@ class DiagnosisRunner:
stop = int(img.number_of_regions() * stop_percentage)
print(f'Processing chips {start} - {stop}')
- return islice(img, start, stop)
-
- def do_diagnosis(
- self,
- image,
- db,
- request_id,
- process_id,
- start_percentage,
- stop_percentage
- ):
- """ converts filepath to region stream, then adds diagnosis to status """
- region_stream = self.make_region_stream(image, start_percentage, stop_percentage)
-
- diagnosis = self.model.diagnose(
- region_stream,
- total_chips=image.number_of_regions(),
- db=db,
- request_id=request_id,
- process_id=process_id
- )
- return diagnosis
-
-
-model_manager_for_web_app.config.DEFAULT_MODELS_DIR = '/home/ec2-user/models/wrapped_models'
-MODEL_NAME = 'zrahn_efficientnet_0219_full_parallel'
-DIAGNOSE_RUNNER = DiagnosisRunner(MODEL_NAME)
-PARALLEL = False
-
-
-class DiagnosisRunner:
- def __init__(self, model_name) -> None:
- """ model_name is a model handled by ModelManagerForWebApp"""
- self.model_name = model_name
- self.model = ModelManager().load_model(self.model_name)
-
- def make_region_stream(self, img, start_percentage, stop_percentage):
- """ break up iterator for parallelization """
-
- start = int(img.number_of_regions() * start_percentage)
- stop = int(img.number_of_regions() * stop_percentage)
- print(f'Processing chips {start} - {stop}')
-
- return islice(img, start, stop)
+        return islice(img, start, stop)
def do_diagnosis(
self,
@@ -111,6 +70,10 @@ class DiagnosisRunner:
db=db,
request_id=request_id,
process_id=process_id,
+            save_moderate=False,
+            save_severe=True,
+            save_bucket='digpath-predictions',
+            save_prefix=f'{request_id}/save_severe',
parallel=PARALLEL
)
return diagnosis
@@ -119,6 +82,7 @@ class DiagnosisRunner:
model_manager_for_web_app.config.DEFAULT_MODELS_DIR = '/home/ec2-user/models/wrapped_models'
MODEL_NAME = 'zrahn_efficientnet_0219_full_parallel'
DIAGNOSE_RUNNER = DiagnosisRunner(MODEL_NAME)
+PARALLEL = True
def process_request(message_data):
try:
@@ -138,16 +102,16 @@ def process_request(message_data):
if request_info['status'] == 'complete':
print(f"Deleting {image.filepath}")
- #os.remove(image.filepath)
+            os.remove(image.filepath)  # NOTE(review): now actually deletes the uploaded slide — confirm no retry/reprocess path still needs it
# results_dir = '/home/ec2-user/data/results'
# with open(f"{results_dir}/{MODEL_NAME}_{message_data['file']}_diagnosis.json", 'w') as f:
- # json.dump(diagnosis, f, indent=4)
+ # json.dump(diagnosis_results, f, indent=4)
return diagnosis_results
- except Exception as e:
- print(e)
+    except Exception:
+ traceback.print_exc()
DIGPATH_DB.update_request(
message_data['request_id'],
None,
@@ -157,6 +121,7 @@ def process_request(message_data):
0,
"error"
)
+        with contextlib.suppress(OSError): os.remove(message_data['file'])
return {}
if __name__ == "__main__":
@@ -166,14 +131,14 @@ if __name__ == "__main__":
ray.init()
while True:
- messages = QUEUE.receive_messages()
- for message in messages:
- try:
+ try:
+ messages = QUEUE.receive_messages()
+ for message in messages:
message_body = message.body
message.delete()
print(f"processing message: {message_body}")
message_data = json.loads(message_body)
results = process_request(message_data)
- except Exception as e:
- print(e)
+ except Exception as e:
+ print(e)
diff --git a/ml/put_model_into_webapp.ipynb b/ml/put_model_into_webapp.ipynb
index 9bbaf86dafb3a086277978616bca01413f8892d5..c0602574fc8b99bec755b89dde3c5253852f4920 100644
--- a/ml/put_model_into_webapp.ipynb
+++ b/ml/put_model_into_webapp.ipynb
@@ -15,6 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
+ "import io\n",
"import os\n",
"import time\n",
"import uuid\n",
@@ -22,6 +23,7 @@
"\n",
"import ray\n",
"import PIL\n",
+ "import boto3\n",
"import torch\n",
"import numpy as np\n",
"\n",
@@ -86,25 +88,57 @@
"metadata": {},
"outputs": [],
"source": [
- "def predict_region(model, region, filtration, save_moderate, save_severe, save_dir):\n",
+ "def predict_region(\n",
+ " model,\n",
+ " region,\n",
+ " filtration,\n",
+ " save_moderate,\n",
+ " save_severe,\n",
+ " save_bucket,\n",
+ " save_prefix\n",
+ "):\n",
" if filtration(region) is True:\n",
" region_probs = model.diagnose_region(region)\n",
" region_probs = softmax(region_probs)\n",
"\n",
" #Save image\n",
- "# if region_probs[2] > 0.75 and save_severe:\n",
- "# im = Image.fromarray(region)\n",
- "# #TODO: send to S3\n",
- "# #im.save(f\"{save_dir}/{uuid.uuid4()}_severe_prediction.png\")\n",
- "# del im\n",
+ " if region_probs[1] > 0.85 and save_moderate:\n",
+ " s3 = boto3.client(\"s3\")\n",
+ " img = Image.fromarray(region)\n",
+ " in_mem_file = io.BytesIO()\n",
+ " img.save(in_mem_file, format='png')\n",
+ " in_mem_file.seek(0)\n",
+ " s3.upload_fileobj(in_mem_file, save_bucket, f\"{save_prefix}/{uuid.uuid4()}.png\")\n",
+ " del img\n",
+ " del in_mem_file\n",
+ "\n",
+ " #Save image\n",
+ " if region_probs[2] > 0.85 and save_severe:\n",
+ " s3 = boto3.client(\"s3\")\n",
+ " img = Image.fromarray(region)\n",
+ " in_mem_file = io.BytesIO()\n",
+ " img.save(in_mem_file, format='png')\n",
+ " in_mem_file.seek(0)\n",
+ " s3.upload_fileobj(in_mem_file, save_bucket, f\"{save_prefix}/{uuid.uuid4()}.png\")\n",
+ " del img\n",
+ " del in_mem_file\n",
"\n",
" del region\n",
"\n",
" return region_probs\n",
"\n",
"@ray.remote\n",
- "def predict_region_parallel(predict_fuction, model, region, filtration, save_moderate, save_severe, save_dir):\n",
- " return predict_fuction(model, region, filtration, save_moderate, save_severe, save_dir)\n",
+    "def predict_region_parallel(\n",
+    "    predict_function,\n",
+    "    model,\n",
+    "    region,\n",
+    "    filtration,\n",
+    "    save_moderate,\n",
+    "    save_severe,\n",
+    "    save_bucket,\n",
+    "    save_prefix\n",
+    "):\n",
+    "    return predict_function(model, region, filtration, save_moderate, save_severe, save_bucket, save_prefix)\n",
"\n",
"\n",
"class WrappedModel(ManagedModel):\n",
@@ -139,7 +173,8 @@
" moderate_chip_thresh = 100,\n",
" save_moderate = False,\n",
" save_severe = False,\n",
- " save_dir='./',\n",
+ " save_bucket = 'digpath-predictions',\n",
+ " save_prefix = '',\n",
" parallel=False\n",
" ):\n",
" \"\"\"\n",
@@ -159,9 +194,6 @@
" self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" self.model.model.to(self._device)\n",
"\n",
- " if request_id is not None:\n",
- " save_dir += request_id + '/'\n",
- "\n",
" # initialize the votes\n",
" votes = {i:0 for i in range(len(self.classes))}\n",
"\n",
@@ -174,8 +206,9 @@
"\n",
" parallel_batch_size = 500\n",
" for region_idx, region in enumerate(region_stream):\n",
+ "\n",
+ " #model predictions\n",
" if parallel:\n",
- " #parallel model predictions\n",
" futures.append(predict_region_parallel.remote(\n",
" predict_function_id,\n",
" model_id,\n",
@@ -183,7 +216,8 @@
" filtration_id,\n",
" save_moderate,\n",
" save_severe,\n",
- " save_dir\n",
+ " save_bucket,\n",
+ " save_prefix\n",
" ))\n",
" else:\n",
" futures.append(predict_region(\n",
@@ -192,9 +226,9 @@
" self.filtration,\n",
" save_moderate,\n",
" save_severe,\n",
- " save_dir\n",
+ " save_bucket,\n",
+ " save_prefix\n",
" ))\n",
- "\n",
" del region\n",
"\n",
" # Aggregate batch\n",
@@ -231,6 +265,7 @@
" 'in progress'\n",
" )\n",
"\n",
+ " # Display progress\n",
" mild = votes[self.classes.index('Mild')]\n",
" moderate = votes[self.classes.index('Moderate')]\n",
" severe = votes[self.classes.index('Severe')]\n",
diff --git a/ml/test_e2e.py b/ml/test_e2e.py
index 96396c808fe09ca781fe36172b83f3c27f081079..578c512c3b1479cb92b4c9c91350a6cb767a12d0 100644
--- a/ml/test_e2e.py
+++ b/ml/test_e2e.py
@@ -18,7 +18,7 @@ def process_file(s3_file):
status = 'in progress'
while status not in ('complete', 'error'):
- response = requests.get('http://localhost:5000/request_info', json={"request_id": request_id})
+ response = requests.get(f'http://localhost:5000/request_info/{request_id}')
status = response.json()['status']
time.sleep(20)
@@ -31,8 +31,10 @@ s3_client = boto3.client('s3')
s3_resource = boto3.resource('s3')
image_bucket = s3_resource.Bucket(image_s3_bucket_name)
-for label in ['Moderate', 'Mild', 'Severe']:
- for s3_obj in image_bucket.objects.filter(Prefix=f'{label}/'):
+for label in ['Mild']: # ['Moderate', 'Mild', 'Severe']:
+ s3_files = list(image_bucket.objects.filter(Prefix=f'{label}/'))
+ s3_files.reverse()
+ for s3_obj in s3_files:
if s3_obj.key == f'{label}/':
continue