diff --git a/.gitignore b/.gitignore
index 6f6b0c69f7a7b59bfb163805c2d9dab451c28817..c3e0946bd2ac98dcb17af25068d0be4b5bc5c40b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
venv/
__pycache__
-.vscode/
\ No newline at end of file
+.vscode/
+.ipynb_checkpoints
+diagnosis_service/db
\ No newline at end of file
diff --git a/diagnosis_service/models/zrahn_efficientnet_0129.model b/diagnosis_service/models/zrahn_efficientnet_0129.model
new file mode 100644
index 0000000000000000000000000000000000000000..e3fb4d8d2c631d49029515d767f8b3f6ce563ae2
Binary files /dev/null and b/diagnosis_service/models/zrahn_efficientnet_0129.model differ
diff --git a/diagnosis_service/models/zrahn_efficientnet_0129.model_info b/diagnosis_service/models/zrahn_efficientnet_0129.model_info
new file mode 100644
index 0000000000000000000000000000000000000000..d63f5d552e9e7c3c12aa9cb4d8bcfdd72c4cc1cd
--- /dev/null
+++ b/diagnosis_service/models/zrahn_efficientnet_0129.model_info
@@ -0,0 +1,3 @@
+{
+ "info": "pretrained efficientnet fine-tuned 01/29/2023 on chips (including simulated)"
+}
\ No newline at end of file
diff --git a/diagnosis_service/setup/setup.py b/diagnosis_service/setup/setup.py
index 3c6b9a91c054cf7f62a85fe2e8598db62942fb10..9d9e67ed1149e43617626da816cb744f5f0ad5bf 100644
--- a/diagnosis_service/setup/setup.py
+++ b/diagnosis_service/setup/setup.py
@@ -1,28 +1,53 @@
-import sqlite3
-
-# connecting to the database
-connection = sqlite3.connect("../db/digpath.db")
-
+# Select database backend: MySQL server (True) or local SQLite database file (False)
+SQL_SERVER = False
+
+if SQL_SERVER:
+ import mysql.connector
+
+ DB_HOST = "localhost"
+ DB_USER = "digpath"
+ DB_PASS = "password"
+ DB_NAME = "digpath"
+ connection = mysql.connector.connect(host=DB_HOST, user=DB_USER, password=DB_PASS, database=DB_NAME)
+else:
+ import sqlite3
+
+ DB_FILE = "/home/ec2-user/db/digpath.db"
+ connection = sqlite3.connect(DB_FILE)
+
# cursor
crsr = connection.cursor()
-
+
# SQL command to create a table in the database
sql_command = """
CREATE TABLE requests (
- request_id TEXT(255) PRIMARY KEY,
- file TEXT(255),
- diagnosis TEXT(255),
+ request_id VARCHAR(255) PRIMARY KEY,
+ file VARCHAR(255),
+ diagnosis VARCHAR(255),
total_chips INTEGER,
+ status VARCHAR(255),
+ timestamp VARCHAR(255),
+ last_update VARCHAR(255)
+ );
+"""
+
+# execute the statement
+crsr.execute(sql_command)
+
+# SQL command to create a table in the database
+sql_command = """
+ CREATE TABLE processing (
+ process VARCHAR(255) PRIMARY KEY,
+ request_id VARCHAR(255),
mild INTEGER,
moderate INTEGER,
severe INTEGER,
- status TEXT(255),
- timestamp TEXT(255)
+ status VARCHAR(255)
);
"""
-
+
# execute the statement
crsr.execute(sql_command)
-
+
# close the connection
-connection.close()
\ No newline at end of file
+connection.close()
diff --git a/diagnosis_service/svc/database_connection.py b/diagnosis_service/svc/database_connection.py
index 42346c944196396d25bba4cb819ce9710e170540..4b64b94df2b4249e2af34da7d1d60c71de01fa60 100644
--- a/diagnosis_service/svc/database_connection.py
+++ b/diagnosis_service/svc/database_connection.py
@@ -4,63 +4,142 @@ from datetime import datetime
import sqlite3
+SEVERE_THRESHOLD = 12  # severe chip count above which the slide is diagnosed 'Severe'
+MODERATE_THRESHOLD = 50  # moderate chip count above which the slide is diagnosed 'Moderate'
class DigpathDatabase:
def __init__(self, connection):
self._db = connection
- def new_request(self, file, request_id=None):
- if request_id is None:
- request_id = str(uuid.uuid4())
-
+ def new_request(self, file, request_id, processes):
cur = self._db.cursor()
cur.execute(
- 'INSERT INTO requests (request_id, file, total_chips, mild, moderate, severe, status, timestamp) VALUES(?,?,?,?,?,?,?,?)',
- (request_id, file, 50000, 0, 0, 0, 'in progress', datetime.now().strftime("%Y-%m-%dT%H:%M:%S"))
+ "INSERT INTO requests (request_id, file, total_chips, status, timestamp) \
+ VALUES('%s','%s',%s,'%s','%s');" % (
+ request_id, file, 0, 'in progress', datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+ )
)
self._db.commit()
- return cur.lastrowid
+
+ for process_id in processes:
+ cur.execute(
+ "INSERT INTO processing (process, request_id, mild, moderate, severe, status) \
+ VALUES('%s','%s',%s,%s,%s,'%s');" % (process_id, request_id, 0, 0, 0, 'submitted')
+ )
+ self._db.commit()
def update_request(
self,
request_id,
- diagnosis,
+ process_id,
+ total_chips,
mild,
moderate,
severe,
status,
):
+ #Update requests table
query = 'UPDATE requests SET '
- if diagnosis is not None:
- query += 'diagnosis="%s", ' % (diagnosis)
+ if total_chips is not None:
+ query += ' total_chips="%s",' % (total_chips)
- query += 'mild=%s, moderate=%s, severe=%s, status="%s" WHERE request_id="%s";' % (
- mild, moderate, severe, status, request_id
+ query += ' last_update="%s" WHERE request_id="%s";' % (
+ datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
+ request_id
)
self._db.cursor().execute(query, ())
self._db.commit()
+ #Update processing table
+ if process_id is not None:
+ query = 'UPDATE processing \
+ SET mild=%s, moderate=%s, severe=%s, status="%s" \
+ WHERE process="%s";' % (
+ mild, moderate, severe, status, process_id
+ )
+
+ self._db.cursor().execute(query, ())
+ self._db.commit()
+
+ #Checking if the overall request is complete
+ db_cursor = self._db.cursor()
+ db_cursor.execute(
+ 'SELECT process, status, mild, moderate, severe \
+ FROM processing WHERE request_id = "%s";' % (request_id)
+ )
+ response = db_cursor.fetchall()
+
+ #If the request is complete, set the status and diagnosis
+ moderate = 0
+ severe = 0
+ request_complete = True
+ for process_info in response:
+            moderate += process_info[3]
+            severe += process_info[4]
+ if process_info[1] != 'complete':
+ request_complete = False
+
+ if request_complete:
+ diagnosis = 'Mild'
+ if severe > SEVERE_THRESHOLD:
+ diagnosis = 'Severe'
+ elif moderate > MODERATE_THRESHOLD:
+ diagnosis = 'Moderate'
+
+ query = 'UPDATE requests SET diagnosis="%s", status="%s"\
+ WHERE request_id="%s";' % (diagnosis, 'complete', request_id)
+ self._db.cursor().execute(query, ())
+ self._db.commit()
+ elif status == 'error':
+ query = 'UPDATE requests SET status="%s" WHERE request_id="%s";' % ('error', request_id)
+ self._db.cursor().execute(query, ())
+ self._db.commit()
+
+
def get_request_info(self, request_id):
- cursor = self._db.cursor().execute('SELECT * FROM requests WHERE request_id = "%s";' % (request_id))
- response = cursor.fetchall()
+ #Get information about the request
+ db_cursor = self._db.cursor()
+ db_cursor.execute('SELECT * FROM requests WHERE request_id = "%s";' % (request_id))
+ response = db_cursor.fetchall()
if len(response) > 0:
request_info = response[0]
else:
print("Request not found")
+ #Get the processing data for the request
+ db_cursor = self._db.cursor()
+ db_cursor.execute('SELECT * FROM processing WHERE request_id = "%s";' % (request_id))
+ response = db_cursor.fetchall()
+
+ complete_processes = []
+ #Aggregate the results
+ chip_predictions = {'mild': 0, 'moderate': 0, 'severe': 0}
+ if len(response) > 0:
+ for process_info in response:
+ chip_predictions['mild'] += process_info[2]
+ chip_predictions['moderate'] += process_info[3]
+ chip_predictions['severe'] += process_info[4]
+ complete_processes.append(process_info[5] == 'complete')
+ else:
+ print("No processing data in this request")
+
+ status = request_info[4]
+ if len(complete_processes) > 0 and all(complete_processes):
+ status = 'complete'
+
results = {
'request_id': request_info[0],
'file': request_info[1],
'diagnosis': request_info[2],
'total_chips': request_info[3],
- 'mild': request_info[4],
- 'moderate': request_info[5],
- 'severe': request_info[6],
- 'status': request_info[7],
- 'timestamp': request_info[8]
+ 'mild': chip_predictions['mild'],
+ 'moderate': chip_predictions['moderate'],
+ 'severe': chip_predictions['severe'],
+ 'status': status,
+ 'timestamp': request_info[5],
+ 'last_update': request_info[6],
}
return results
-
diff --git a/diagnosis_service/svc/prediction_listener.py b/diagnosis_service/svc/prediction_listener.py
index e53977bedadf9072c452c580f3ba47c8a1fd474b..76caeff78fc367fe1dca2da4ad4501dbeb1364b9 100644
--- a/diagnosis_service/svc/prediction_listener.py
+++ b/diagnosis_service/svc/prediction_listener.py
@@ -1,68 +1,179 @@
+import contextlib
import json
+import os
import random
import time
+from itertools import islice
+
+import ray
import boto3
-import sqlite3
+import model_manager_for_web_app
+
from database_connection import DigpathDatabase
+from unified_image_reader import Image
+from model_manager_for_web_app import ModelManager
+
+DIAGNOSIS_RUNNER_STATUS_LOCK_TIMEOUT = 1
+
+QUEUE_NAME = "digpath-request"
+QUEUE = boto3.resource("sqs").get_queue_by_name(QueueName=QUEUE_NAME)
+
+SQL_SERVER = False
+if SQL_SERVER:
+ import mysql.connector
+ DB_HOST = "localhost"
+ DB_USER = "digpath"
+ DB_PASS = "password"
+ DB_NAME = "digpath"
+ DIGPATH_DB = DigpathDatabase(
+ mysql.connector.connect(host=DB_HOST, user=DB_USER, password=DB_PASS, database=DB_NAME)
+ )
+else:
+ import sqlite3
+ DB_FILE = "/home/ec2-user/db/digpath.db"
+ DIGPATH_DB = DigpathDatabase(sqlite3.connect(DB_FILE, check_same_thread=False))
+
+
+class DiagnosisRunner:
+ def __init__(self, model_name) -> None:
+ """ model_name is a model handled by ModelManagerForWebApp"""
+ self.model_name = model_name
+ self.model = ModelManager().load_model(self.model_name)
+
+ def make_region_stream(self, img, start_percentage, stop_percentage):
+ """ break up iterator for parallelization """
+
+ start = int(img.number_of_regions() * start_percentage)
+ stop = int(img.number_of_regions() * stop_percentage)
+ print(f'Processing chips {start} - {stop}')
+
+ return islice(img, start, stop)
+
+ def do_diagnosis(
+ self,
+ image,
+ db,
+ request_id,
+ process_id,
+ start_percentage,
+ stop_percentage
+ ):
+ """ converts filepath to region stream, then adds diagnosis to status """
+ region_stream = self.make_region_stream(image, start_percentage, stop_percentage)
+
+ diagnosis = self.model.diagnose(
+ region_stream,
+ total_chips=image.number_of_regions(),
+ db=db,
+ request_id=request_id,
+ process_id=process_id
+ )
+ return diagnosis
+
+
+model_manager_for_web_app.config.DEFAULT_MODELS_DIR = '/home/ec2-user/models/wrapped_models'
+MODEL_NAME = 'zrahn_efficientnet_0219_full_parallel'
+DIAGNOSE_RUNNER = DiagnosisRunner(MODEL_NAME)
+PARALLEL = False
+
-queue_name = "digpath-request"
-db_connection = sqlite3.connect("../db/digpath.db", check_same_thread=False)
-
-sqs = boto3.resource("sqs")
-queue = sqs.get_queue_by_name(QueueName=queue_name)
-digpath_db = DigpathDatabase(db_connection)
-
-def process_request(message_body):
- print(f"processing message: {message_body}")
- message_data = json.loads(message_body)
- request_info = digpath_db.get_request_info(message_data['request_id'])
-
- #TODO: replace the code below with the ML model predictions - below are fake ML predictions
- mild = 0
- moderate = 0
- severe = 0
- diagnosis = request_info['diagnosis']
- status = request_info['status']
- while status == 'in progress':
- if random.random() < 0.03:
- severe += random.randint(1, 10)
- elif random.random() < 0.12:
- moderate += random.randint(1, 20)
- mild += (1000 - severe - moderate)
-
- if severe > 15:
- diagnosis = 'severe'
- status = 'complete'
- elif moderate > 100:
- diagnosis = 'moderate'
- status = 'complete'
- elif mild + moderate + severe >= request_info['total_chips']:
- mild = request_info['total_chips'] - severe - moderate
- diagnosis = 'mild'
- status = 'complete'
-
- digpath_db.update_request(
- request_info['request_id'],
- diagnosis,
- mild,
- moderate,
- severe,
- status
+class DiagnosisRunner:  # NOTE(review): redefines the DiagnosisRunner declared above; this later definition (which passes parallel=PARALLEL) is the one that takes effect
+ def __init__(self, model_name) -> None:
+ """ model_name is a model handled by ModelManagerForWebApp"""
+ self.model_name = model_name
+ self.model = ModelManager().load_model(self.model_name)
+
+ def make_region_stream(self, img, start_percentage, stop_percentage):
+ """ break up iterator for parallelization """
+
+ start = int(img.number_of_regions() * start_percentage)
+ stop = int(img.number_of_regions() * stop_percentage)
+ print(f'Processing chips {start} - {stop}')
+
+ return islice(img, start, stop)
+
+ def do_diagnosis(
+ self,
+ image,
+ db,
+ request_id,
+ process_id,
+ start_percentage,
+ stop_percentage
+ ):
+ """ converts filepath to region stream, then adds diagnosis to status """
+ region_stream = self.make_region_stream(image, start_percentage, stop_percentage)
+
+ diagnosis = self.model.diagnose(
+ region_stream,
+ total_chips=image.number_of_regions(),
+ db=db,
+ request_id=request_id,
+ process_id=process_id,
+ parallel=PARALLEL
+ )
+ return diagnosis
+
+
+model_manager_for_web_app.config.DEFAULT_MODELS_DIR = '/home/ec2-user/models/wrapped_models'
+MODEL_NAME = 'zrahn_efficientnet_0219_full_parallel'
+DIAGNOSE_RUNNER = DiagnosisRunner(MODEL_NAME)
+
+def process_request(message_data):
+ try:
+ image = Image(message_data['file'])
+ diagnosis_results = DIAGNOSE_RUNNER.do_diagnosis(
+ image,
+ DIGPATH_DB,
+ message_data['request_id'],
+ message_data['process_id'],
+ message_data['start_percentage'],
+ message_data['stop_percentage'],
)
- time.sleep(1)
- print(f"Processing of {message_data['request_id']} complete")
+ print(f"Request: {message_data['request_id']}, Process: {message_data['process_id']} complete")
+ request_info = DIGPATH_DB.get_request_info(message_data['request_id'])
+ print(json.dumps(request_info, indent=2))
+
+ if request_info['status'] == 'complete':
+ print(f"Deleting {image.filepath}")
+ #os.remove(image.filepath)
+
+ # results_dir = '/home/ec2-user/data/results'
+ # with open(f"{results_dir}/{MODEL_NAME}_{message_data['file']}_diagnosis.json", 'w') as f:
+ # json.dump(diagnosis, f, indent=4)
+
+ return diagnosis_results
+
+ except Exception as e:
+ print(e)
+ DIGPATH_DB.update_request(
+ message_data['request_id'],
+ None,
+ None,
+ 0,
+ 0,
+ 0,
+ "error"
+ )
+ return {}
if __name__ == "__main__":
- print(f"Diagnosis listener running. Listening for messages in the '{queue_name}' queue")
+ print(f"Diagnosis listener running. Listening for messages in the '{QUEUE_NAME}' queue")
+
+ if PARALLEL:
+ ray.init()
+
while True:
- messages = queue.receive_messages()
+ messages = QUEUE.receive_messages()
for message in messages:
try:
message_body = message.body
message.delete()
- process_request(message_body)
+ print(f"processing message: {message_body}")
+ message_data = json.loads(message_body)
+ results = process_request(message_data)
except Exception as e:
print(e)
diff --git a/diagnosis_service/svc/run.py b/diagnosis_service/svc/run.py
index 19a1c78af462019977fc0a7b805ded4d7ebd81b5..af895e47af6ed5b50b91171101f87c84dead9e3b 100644
--- a/diagnosis_service/svc/run.py
+++ b/diagnosis_service/svc/run.py
@@ -2,16 +2,43 @@ import json
import uuid
import boto3
-import sqlite3
+# NOTE: database connectors (mysql.connector / sqlite3) are imported conditionally below based on SQL_SERVER
from flask import Flask, request, jsonify
from database_connection import DigpathDatabase
-queue_url = "https://sqs.us-east-1.amazonaws.com/432722299252/digpath-request"
-db_connection = sqlite3.connect("../db/digpath.db", check_same_thread=False)
+SLICE_BREAKUP = [1]
+
+QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/432722299252/digpath-request"
+SQS = boto3.client('sqs')
+
+SQL_SERVER = False
+if SQL_SERVER:
+ import mysql.connector
+ DB_HOST = "localhost"
+ DB_USER = "digpath"
+ DB_PASS = "password"
+ DB_NAME = "digpath"
+ DIGPATH_DB = DigpathDatabase(
+ mysql.connector.connect(host=DB_HOST, user=DB_USER, password=DB_PASS, database=DB_NAME)
+ )
+else:
+ import sqlite3
+ DB_FILE = "/home/ec2-user/db/digpath.db"
+ DIGPATH_DB = DigpathDatabase(sqlite3.connect(DB_FILE, check_same_thread=False))
app = Flask(__name__)
-digpath_db = DigpathDatabase(db_connection)
-sqs = boto3.client('sqs')
+
+def get_slice_percentage(slice_num):
+ start = 0
+ for slice_percentage in SLICE_BREAKUP[:slice_num]:
+ start += slice_percentage
+
+ if (slice_num + 1) == len(SLICE_BREAKUP):
+ stop = 1
+ else:
+ stop = start + SLICE_BREAKUP[slice_num]
+
+ return start, stop
@app.route('/diagnose')
def diagnoseEndpoint():
@@ -21,16 +48,24 @@ def diagnoseEndpoint():
print(f"Request received: \nRequest ID: {request_id} \n{json.dumps(content, indent=2)}")
#Insert request into database
- digpath_db.new_request(content['file'], request_id)
+ processes = [str(uuid.uuid4()) for _ in range(len(SLICE_BREAKUP))]
+ DIGPATH_DB.new_request(content['file'], request_id, processes)
#Send SQS message
- sqs_response = sqs.send_message(
- QueueUrl=queue_url,
- MessageBody=json.dumps({
- 'request_id': request_id,
- 'file': content['file']
- }, indent=2)
- )
+ for slice_number, process in enumerate(processes):
+ slice_start, slice_stop = get_slice_percentage(slice_number)
+
+ sqs_response = SQS.send_message(
+ QueueUrl=QUEUE_URL,
+ MessageBody=json.dumps({
+ 'request_id': request_id,
+ 'process_id': process,
+ 'slice_number': slice_number,
+ 'start_percentage': slice_start,
+ 'stop_percentage': slice_stop,
+ 'file': content['file']
+ }, indent=2)
+ )
print(f'SQS Message: \n{sqs_response}')
response = {
@@ -42,9 +77,9 @@ def diagnoseEndpoint():
def getDiagnosisEndpoint():
content = request.json
- print(f"Getting info for: \n{json.dumps(content, indent=2)}")
+ print(f"Getting info for: {content}")
- request_info = digpath_db.get_request_info(content['request_id'])
+ request_info = DIGPATH_DB.get_request_info(content['request_id'])
return jsonify(request_info)
if __name__ == '__main__':
diff --git a/ml/Evaluate_Results.ipynb b/ml/Evaluate_Results.ipynb
index 63a81bb769a9473d82206e0d2d1b1422342a4c46..9f0cbc8c8470b1cfc9d0fbe8a39cc7e9f89254b7 100644
--- a/ml/Evaluate_Results.ipynb
+++ b/ml/Evaluate_Results.ipynb
@@ -809,7 +809,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.15"
+ "version": "3.10.8"
}
},
"nbformat": 4,
diff --git a/ml/ml_prediction.py b/ml/ml_prediction.py
index bbc2887de5fb06429a611509b50d4b9e9a71071c..80c042ce0a15319050763f437c49deb807f215ce 100644
--- a/ml/ml_prediction.py
+++ b/ml/ml_prediction.py
@@ -2,11 +2,12 @@ import os
import contextlib
import json
import glob
-import shutil
+import random
import sys
import threading
import uuid
+import boto3
import model_manager_for_web_app
from unified_image_reader import Image
from model_manager_for_web_app import ModelManager
@@ -60,13 +61,15 @@ class DiagnosisRunner:
self.update_status(current_region=region_num)
yield region
- def do_diagnosis(self, filepath):
+ def do_diagnosis(self, filepath, save_dir):
""" converts filepath to region stream, then adds diagnosis to status """
region_stream = self.make_region_stream(filepath)
diagnosis = self.model.diagnose(
region_stream,
- stream_length=Image(filepath).number_of_regions(),
+ total_chips=Image(filepath).number_of_regions(),
+ save_moderate=True,
save_severe=True,
+ save_dir=save_dir
)
self.update_status(diagnosis=diagnosis['vote'])
return diagnosis
@@ -83,85 +86,102 @@ class DiagnosisRunner:
self._status[key] = value
run_id = str(uuid.uuid4().hex)
-model_manager_for_web_app.config.DEFAULT_MODELS_DIR = '/home/ec2-user/efs/repos/models'
-model_name = 'zac_efficientnet_hp_sweep'
-imagedir = '/home/ec2-user/efs/data/images/digpath-data/Mild' #TODO: change to digpath-data-new & -batch-3
-#savedir = '/home/ec2-user/efs/data/labels/train'
-savedir = '/home/ec2-user/efs/data/labels/incorrect/actually_Mild'
-batch = int(sys.argv[1])
-files = glob.glob(f'{imagedir}/*')
-results_dir = '/home/ec2-user/efs/data/results'
-saved_tiles = glob.glob(f'{savedir}/*.png')
+model_manager_for_web_app.config.DEFAULT_MODELS_DIR = '/home/ec2-user/efs/models/wrapped_models'
+model_name = 'zrahn_efficientnet_0219_full'
+diagnose_runner = DiagnosisRunner(model_name)
+
+image_labels = 'Mild'
+
+image_s3_bucket_name = 'digpath-data' #TODO: digpath-data-new & -batch-3
+image_s3_prefix = f'{image_labels}/'
+chips_s3_bucket_name = 'digpath-chips'
+chips_s3_prefix = f'new/{image_labels}/'
+
+s3_client = boto3.client('s3')
+s3_resource = boto3.resource('s3')
+image_bucket = s3_resource.Bucket(image_s3_bucket_name)
+
+data_dir = '/home/ec2-user/data'
+imagedir = f'{data_dir}/images/{image_labels}'
+savedir = f'{data_dir}/chips/{image_labels}'
+
+results_dir = f'{data_dir}/results'
+
+batch = -1
+if len(sys.argv) >= 2:
+ batch = int(sys.argv[1])
+
+files = list(image_bucket.objects.all())
start = 0
end = len(files)
-print(f'Processing batch: {batch}')
+
if batch == 0:
start = 0
end = int(0.25 * len(files))
- #savedir = '/home/ec2-user/efs/data/labels/train'
elif batch == 1:
start = int(0.25 * len(files))
end = int(0.5 * len(files))
- #savedir = '/home/ec2-user/efs/data/labels/train'
elif batch == 2:
start = int(0.5 * len(files))
end = int(0.75 * len(files))
- #savedir = '/home/ec2-user/efs/data/labels/val'
elif batch == 3:
start = int(0.75 * len(files))
end = len(files)
- #savedir = '/home/ec2-user/efs/data/labels/test'
-files.sort()
-for img_filepath in files[start:end]:
+#files = sorted(files, key=lambda x: x.key)
+random.shuffle(files)
+
+for img_object in files[start:end]:
try:
results = {}
- img_filename = img_filepath.split('/')[-1]
- # for saved_file in saved_tiles:
- # if img_filename in saved_file:
- # continue
+ img_name = img_object.key
+ img_filename = img_name.split('/')[-1]
+
+ # Check if this file already has results
processed = False
- for saved_file in glob.glob(f'{results_dir}/*'):
- if img_filename in saved_file:
+ for processed_file in glob.glob(f'{results_dir}/*'):
+ if img_filename in processed_file:
processed = True
break
if processed:
continue
- print('Processing ' + img_filepath)
- temp_img_filename = f'/home/ec2-user/temp/{img_filename}'
- shutil.move(img_filepath, temp_img_filename)
- diagnose_runner = DiagnosisRunner(model_name)
- diagnosis = diagnose_runner.do_diagnosis(temp_img_filename)
-
+ # Upload any image chips from previous runs
+ for saved_file in glob.glob(f'{savedir}/*.png'):
+ s3_client.upload_file(
+ saved_file,
+ chips_s3_bucket_name,
+ f"{chips_s3_prefix}/{saved_file.split('/')[-1]}"
+ )
+ os.remove(saved_file)
+
+ print('Processing ' + img_name)
+ img_path = f'{imagedir}/{img_filename}'
+ s3_client.download_file(image_s3_bucket_name, img_name, img_path)
+ diagnosis = diagnose_runner.do_diagnosis(img_path, savedir)
+ diagnosis['label'] = image_labels
+ diagnosis['file'] = img_path
+
+ # Save results
results_filename = f'{results_dir}/{model_name}_{img_filename}_{run_id}_diagnosis.json'
with open(results_filename, 'w') as f:
json.dump(diagnosis, f, indent=4)
- shutil.move(temp_img_filename, img_filepath)
-
- # chip_files = glob.glob('/home/ec2-user/temp/Severe/*.png')
- # print(f'Saving {len(chip_files)} files')
- # for chip_filepath in chip_files:
- # chip_filename = chip_filepath.split('/')[-1]
- # shutil.move(chip_filepath, f'{savedir}/Severe/{chip_filename}')
-
- chip_files = glob.glob('/home/ec2-user/temp/actually_Mild/actually_Mild/*.png')
- print(f'Saving {len(chip_files)} files')
- for chip_filepath in chip_files:
- chip_filename = chip_filepath.split('/')[-1]
- shutil.move(chip_filepath, f'{savedir}/{chip_filename}')
+ # Delete the file
+ os.remove(img_path)
except Exception as e:
print(e)
try:
- if len(glob.glob(temp_img_filename)) == 1:
- shutil.move(temp_img_filename, img_filepath)
-
- # chip_files = glob.glob('/home/ec2-user/temp/Severe/*.png')
- # print(f'Saving {len(chip_files)} files')
- # for chip_filepath in chip_files:
- # chip_filename = chip_filepath.split('/')[-1]
- # shutil.move(chip_filepath, f'{savedir}/Severe/{chip_filename}')
+ if len(glob.glob(img_path)) == 1:
+ os.remove(img_path)
+ pass
+
+ for saved_file in glob.glob(f'{savedir}/*.png'):
+ s3_client.upload_file(
+ saved_file,
+ chips_s3_bucket_name,
+ f"{chips_s3_prefix}/{saved_file.split('/')[-1]}"
+ )
except Exception:
pass
diff --git a/ml/my_model.py b/ml/my_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c74a087f2ad0ef394a052c535c53c7e92e57f55f
--- /dev/null
+++ b/ml/my_model.py
@@ -0,0 +1,221 @@
+from typing import Callable, Optional
+import torch
+import os
+import numpy as np
+
+from tqdm import tqdm
+from torch import nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from torch.nn.parallel import DistributedDataParallel as DDP
+from sklearn.metrics import confusion_matrix
+
+from unified_image_reader import Image
+
+
+class MyModel:
+ """
+ _summary_
+ """
+
+ def __init__(self, model: nn.Module, loss_fn: nn.Module, device: str, checkpoint_dir: str, model_dir: str, optimizer: Optimizer):
+ """
+ __init__ _summary_
+
+ :param model: PyTorch model
+ :type model: nn.Module
+ :param loss_fn: PyTorch Loss Function
+ :type loss_fn: nn.Module
+ :param device: Device Type
+ :type device: str
+ :param checkpoint_dir: Filepath to checkpoint directory for mid train saving
+ :type checkpoint_dir: str
+ :param model_dir: Filepath to output directory for final model saving
+ :type model_dir: str
+ :param optimizer: PyTorch Optimization Function
+ :type optimizer: Optimizer
+ """
+ self.model = model
+ self.loss_fn = loss_fn
+ self.device = device
+ phases = ["train"]
+ num_classes = 3
+ self.all_acc = {key: 0 for key in phases}
+ self.all_loss = {
+ key: torch.zeros(0, dtype=torch.float64).to(device)
+ for key in phases
+ }
+ self.cmatrix = {key: np.zeros(
+ (num_classes, num_classes)) for key in phases}
+ self.model_dir = model_dir
+ self.checkpoint_dir = checkpoint_dir
+ self.optimizer = optimizer
+
+ def parallel(self, distributed: bool = False):
+ """
+ parallel Prepares model for distributed learning
+ :param distributed: Determines if distributed learning is occurring, defaults to False
+ :type distributed: bool, optional
+ """
+ if distributed:
+ self.model = DDP(self.model)
+ elif torch.cuda.device_count() > 1:
+ print(f"Gpu count: {torch.cuda.device_count()}")
+ self.model = nn.DataParallel(self.model)
+
+ def train_model(self, data_loader: DataLoader):
+ """
+ train_model Performs model training
+
+ :param data_loader: DataLoader of training set data
+ :type data_loader: DataLoader
+ """
+ self.all_loss['train'] = torch.zeros(
+ 0, dtype=torch.float64).to(self.device)
+ self.model.train()
+ for ii, (X, label) in enumerate(data_loader):
+ X = X.to(self.device)
+ label = label.type('torch.LongTensor').to(self.device)
+ with torch.set_grad_enabled(True):
+ prediction = self.model(X.permute(0, 3, 1,
+ 2).float()) # [N, Nclass]
+ loss = self.loss_fn(prediction, label)
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+ self.all_loss['train'] = torch.cat(
+ (self.all_loss['train'], loss.detach().view(1, -1)))
+ self.all_acc['train'] = (self.cmatrix['train'] /
+ (self.cmatrix['train'].sum() + 1e-6)).trace()
+ self.all_loss['train'] = self.all_loss['train'].cpu().numpy().mean()
+
+ def eval(self, data_loader: DataLoader, num_classes: int):
+ """
+ eval Performs model validation
+
+ :param data_loader: DataLoader of validation set data
+ :type data_loader: DataLoader
+ :param num_classes: Number of classes passed into the model
+ :type num_classes: int
+ """
+ self.model.eval()
+ self.all_loss['val'] = torch.zeros(
+ 0, dtype=torch.float64).to(self.device)
+ for ii, (X, label) in enumerate((pbar := tqdm(data_loader))):
+ pbar.set_description(f'validation_progress_{ii}', refresh=True)
+ X = X.to(self.device)
+ label = torch.tensor(list(map(int, label))).to(self.device)
+ with torch.no_grad():
+ prediction = self.model(X.permute(0, 3, 1,
+ 2).float()) # [N, Nclass]
+ loss = self.loss_fn(prediction, label)
+ p = prediction.detach().cpu().numpy()
+ cpredflat = np.argmax(p, axis=1).flatten()
+ yflat = label.cpu().numpy().flatten()
+ self.all_loss['val'] = torch.cat(
+ (self.all_loss['val'], loss.detach().view(1, -1)))
+ self.cmatrix['val'] = self.cmatrix['val'] + \
+ confusion_matrix(yflat, cpredflat,
+ labels=range(num_classes))
+ self.all_acc['val'] = (self.cmatrix['val'] /
+ self.cmatrix['val'].sum()).trace()
+ self.all_loss['val'] = self.all_loss['val'].cpu().numpy().mean()
+
+ def save_model(self, filepath: Optional[str] = None):
+ """
+ save_model Saves the model to a specific directory
+
+ :param filepath: path to output directory, defaults to None
+ :type filepath: Optional[str], optional
+ """
+ print("Saving the model.")
+ path = filepath or os.path.join(self.model_dir, 'model.pth')
+ # recommended way from http://pytorch.org/docs/master/notes/serialization.html
+ torch.save(self.model.cpu().state_dict(), path)
+
+ def save_checkpoint(self, state: dict):
+ """
+ save_checkpoint Saves the checkpoint to a specific directory
+
+ :param state: Dictionary of various values
+ :type state: dict
+ """
+ path = os.path.join(self.checkpoint_dir, 'checkpoint.pth')
+ print("Saving the Checkpoint: {}".format(path))
+ torch.save({
+ 'model_state_dict': self.model.state_dict(),
+ 'optimizer_state_dict': self.optimizer.state_dict(),
+ **state
+ }, path)
+
+ def load_checkpoint(self):
+ """
+ load_checkpoint Loads the checkpoint from a specific directory
+
+ :return: The epoch number of the checkpointed model
+ :rtype: int
+ """
+ print("--------------------------------------------")
+ print("Checkpoint file found!")
+ path = os.path.join(self.checkpoint_dir, 'checkpoint.pth')
+ print("Loading Checkpoint From: {}".format(path))
+ checkpoint = torch.load(path)
+ self.model.load_state_dict(checkpoint['model_state_dict'])
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ epoch_number = checkpoint['epoch']
+ loss = checkpoint['best_loss_on_test']
+ print("Checkpoint File Loaded - epoch_number: {} - loss: {}".format(epoch_number, loss))
+ print('Resuming training from epoch: {}'.format(epoch_number + 1))
+ print("--------------------------------------------")
+ return epoch_number
+
+ def load_model(self, filepath: Optional[str] = None):
+ """
+ load_model Loads the model from a specific directory
+
+ :param filepath: path to output directory, defaults to None
+ :type filepath: Optional[str], optional
+ """
+ path = filepath or os.path.join(self.model_dir, 'model.pth')
+ checkpoint = torch.load(path)
+ self.parallel()
+ self.model.load_state_dict(checkpoint)
+
+ def diagnose_region(self, region: np.ndarray, labels: dict = None):
+ """
+ diagnose_region Diagnoses the regions with a specific label
+
+ :param region: A 512 x 512 region
+ :type region: np.ndarray
+ :param labels: Dictionary of labels and their respective integer representations, defaults to None
+ :type labels: dict, optional
+ :return: Prediction of the region based on the labels provided
+ :rtype: str or int
+ """
+ self.model = self.model.to(self.device)
+ region = torch.Tensor(region[None, ::]).permute(0, 3, 1, 2).float().to(self.device)
+ output = self.model(region).to(self.device)
+ return output.detach().squeeze().cpu().numpy()
+
+ def diagnose_wsi(self, file_path: str, aggregate: Callable, classes: tuple, labels: dict = None):
+ """
+ diagnose_wsi Diagnoses the whole slide image with a specific label
+
+ :param file_path: File path to whole slide image
+ :type file_path: str
+ :param aggregate: Aggregation function to collapse the region classifications
+ :type aggregate: Callable
+ :param classes: Tuple of labels used for training
+ :type classes: tuple
+ :param labels: Dictionary of labels and their respective integer representations, defaults to None
+ :type labels: dict, optional
+ :return: Prediction of the region based on the labels provided
+ :rtype: str or int
+ """
+ region_classifications = {}
+ for i, region in enumerate(Image(file_path)):
+ region = region.to(self.device)
+ self.model.eval()
+ pred = self.diagnose_region(region, labels)
+ region_classifications[i] = pred
+ return aggregate(region_classifications, classes)
diff --git a/ml/put_model_into_webapp.ipynb b/ml/put_model_into_webapp.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..9bbaf86dafb3a086277978616bca01413f8892d5
--- /dev/null
+++ b/ml/put_model_into_webapp.ipynb
@@ -0,0 +1,360 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Prerequisites\n",
+ "\n",
+    "This file was last updated on May 3rd, 2022. If any changes have been made to the class found in ./my_model.py since then, there may be issues running this code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import time\n",
+ "import uuid\n",
+ "import utils\n",
+ "\n",
+ "import ray\n",
+ "import PIL\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "from PIL import Image\n",
+ "from tqdm import tqdm\n",
+ "from scipy.special import softmax\n",
+ "\n",
+ "import my_model\n",
+ "from model_manager_for_web_app import ManagedModel\n",
+ "from model_manager_for_web_app import ModelManager\n",
+ "from filtration import FilterManager, FilterBlackAndWhite, FilterHSV, FilterFocusMeasure"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load in Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_dir = \"/home/ec2-user/models/3_class\"\n",
+ "training_run = \"dc5b3a4954de4c1f82e0939ac83e8ac3\"\n",
+ "model_file = \"model_2023-02-20T010424_37.pth\"\n",
+ "path_to_model_weights = f\"{model_dir}/{training_run}/checkpoints/{model_file}\"\n",
+ "\n",
+ "digpath_model = my_model.MyModel(\n",
+ " model = torch.load(path_to_model_weights, map_location=torch.device('cpu')),\n",
+ " loss_fn = None,\n",
+ " device = torch.device('cpu'),\n",
+ " checkpoint_dir= None,\n",
+ " optimizer=None,\n",
+ " model_dir=None\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# img = np.asarray(PIL.Image.open('/home/ec2-user/data/chips/Severe/84277T_001.tif_024153.png'))\n",
+ "# print(digpath_model.diagnose_region(img))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Wrapper on MyModel for WebApp Compatibility"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_region(model, region, filtration, save_moderate, save_severe, save_dir):\n",
+ " if filtration(region) is True:\n",
+ " region_probs = model.diagnose_region(region)\n",
+ " region_probs = softmax(region_probs)\n",
+ "\n",
+ " #Save image\n",
+ "# if region_probs[2] > 0.75 and save_severe:\n",
+ "# im = Image.fromarray(region)\n",
+ "# #TODO: send to S3\n",
+ "# #im.save(f\"{save_dir}/{uuid.uuid4()}_severe_prediction.png\")\n",
+ "# del im\n",
+ "\n",
+ " del region\n",
+ "\n",
+ " return region_probs\n",
+ "\n",
+    "@ray.remote\n",
+    "def predict_region_parallel(predict_function, model, region, filtration, save_moderate, save_severe, save_dir):\n",
+    "    return predict_function(model, region, filtration, save_moderate, save_severe, save_dir)\n",
+ "\n",
+ "\n",
+ "class WrappedModel(ManagedModel):\n",
+ " def __init__(\n",
+ " self,\n",
+ " model,\n",
+ " classes = ('Mild', 'Moderate', 'Severe'),\n",
+ " ):\n",
+ " self.model = model\n",
+ " self.classes = classes\n",
+ " self.filtration = FilterManager([\n",
+ " FilterBlackAndWhite(),\n",
+ " FilterHSV(),\n",
+ " FilterFocusMeasure()\n",
+ " ])\n",
+ " self._device = torch.device('cpu')\n",
+ "\n",
+ " def diagnose_region(self, region):\n",
+ " \"\"\" diagnose single region \"\"\"\n",
+ " return self.model.diagnose_region(region)\n",
+ "\n",
+ " def diagnose(\n",
+ " self,\n",
+ " region_stream,\n",
+ " total_chips,\n",
+ " db=None,\n",
+ " request_id=None,\n",
+ " process_id=None,\n",
+ " severe_prob_thresh = 0.75,\n",
+ " moderate_prob_thresh = 0.75,\n",
+ " severe_chip_thresh = 12,\n",
+ " moderate_chip_thresh = 100,\n",
+ " save_moderate = False,\n",
+ " save_severe = False,\n",
+ " save_dir='./',\n",
+ " parallel=False\n",
+ " ):\n",
+ " \"\"\"\n",
+ " model takes in a stream of regions (numpy arrays) and produces diagnosis\n",
+ "\n",
+ " Example:\n",
+ " # diagnosis is whichever category has the most 'votes'\n",
+ " votes = {'positive':0, 'negative':0}\n",
+ " for region in region_stream:\n",
+ " votes[self.process(region)] += 1\n",
+ " return max(votes, key=votes.get) # key with max value\n",
+ " \"\"\"\n",
+ " start = time.time()\n",
+ "\n",
+ " # first check to see if we can use hardware\n",
+ " if self._device is None:\n",
+ " self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ " self.model.model.to(self._device)\n",
+ "\n",
+ " if request_id is not None:\n",
+ " save_dir += request_id + '/'\n",
+ "\n",
+ " # initialize the votes\n",
+ " votes = {i:0 for i in range(len(self.classes))}\n",
+ "\n",
+ " # setup parallelization\n",
+ " if parallel:\n",
+ " model_id = ray.put(self.model)\n",
+ " filtration_id = ray.put(self.filtration)\n",
+ " predict_function_id = ray.put(predict_region)\n",
+ " futures = []\n",
+ "\n",
+ " parallel_batch_size = 500\n",
+ " for region_idx, region in enumerate(region_stream):\n",
+ " if parallel:\n",
+ " #parallel model predictions\n",
+ " futures.append(predict_region_parallel.remote(\n",
+ " predict_function_id,\n",
+ " model_id,\n",
+ " region,\n",
+ " filtration_id,\n",
+ " save_moderate,\n",
+ " save_severe,\n",
+ " save_dir\n",
+ " ))\n",
+ " else:\n",
+ " futures.append(predict_region(\n",
+ " self.model,\n",
+ " region,\n",
+ " self.filtration,\n",
+ " save_moderate,\n",
+ " save_severe,\n",
+ " save_dir\n",
+ " ))\n",
+ "\n",
+ " del region\n",
+ "\n",
+ " # Aggregate batch\n",
+ " if (\n",
+ " (region_idx % parallel_batch_size) == (parallel_batch_size - 1)\n",
+ " or region_idx == (total_chips - 1)\n",
+ " ):\n",
+ " if parallel:\n",
+ " prediction_results = ray.get(futures)\n",
+ " else:\n",
+ " prediction_results = futures\n",
+ "\n",
+ " #update votes\n",
+ " for probs in prediction_results:\n",
+ " if probs is not None:\n",
+ " if probs[2] > severe_prob_thresh:\n",
+ " region_diagnosis = 2\n",
+ " elif probs[1] > moderate_prob_thresh:\n",
+ " region_diagnosis = 1\n",
+ " else:\n",
+ " region_diagnosis = 0\n",
+ "\n",
+ " votes[region_diagnosis] += 1\n",
+ "\n",
+ " #update db\n",
+ " if db is not None and request_id is not None:\n",
+ " db.update_request(\n",
+ " request_id,\n",
+ " process_id,\n",
+ " total_chips,\n",
+ " votes[self.classes.index('Mild')],\n",
+ " votes[self.classes.index('Moderate')],\n",
+ " votes[self.classes.index('Severe')],\n",
+ " 'in progress'\n",
+ " )\n",
+ "\n",
+ " mild = votes[self.classes.index('Mild')]\n",
+ " moderate = votes[self.classes.index('Moderate')]\n",
+ " severe = votes[self.classes.index('Severe')]\n",
+ "\n",
+ " duration = time.time() - start\n",
+ " time_per_iteration = duration / (region_idx + 1)\n",
+ " expected_completion = time_per_iteration * total_chips\n",
+ "\n",
+ " print()\n",
+ " print(f\"{duration:.1f}s / {expected_completion:.1f}s - {time_per_iteration:.3f}s per region\")\n",
+ " print(f\"({region_idx + 1:05d}/{total_chips:05d})\")\n",
+ " print(f\"Mild {mild:05d} | Moderate: {moderate:05d} | Severe: {severe:05d}\")\n",
+ "\n",
+ " futures = []\n",
+ "\n",
+ " # aggregate the votes\n",
+ " if votes[self.classes.index('Severe')] >= severe_chip_thresh:\n",
+ " vote = self.classes.index('Severe')\n",
+ " elif votes[self.classes.index('Moderate')] >= moderate_chip_thresh:\n",
+ " vote = self.classes.index('Moderate')\n",
+ " else:\n",
+ " vote = self.classes.index('Mild')\n",
+ "\n",
+ " if db is not None and request_id is not None:\n",
+ " db.update_request(\n",
+ " request_id,\n",
+ " process_id,\n",
+ " total_chips,\n",
+ " votes[self.classes.index('Mild')],\n",
+ " votes[self.classes.index('Moderate')],\n",
+ " votes[self.classes.index('Severe')],\n",
+ " 'complete'\n",
+ " )\n",
+ "\n",
+ " return {\n",
+ " 'predictions': {\n",
+ " c: votes[i] for i, c in enumerate(self.classes)\n",
+ " },\n",
+ " 'vote': vote\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Prepare Model for Saving\n",
+ "\n",
+ "1. Wrap the model in a ManagedModel class (you will need to create a subclass of ManagedModel just as above)\n",
+ "2. Register any dependencies that might not be available to the WebApp when this model is deserialized. \n",
+ " - To identify whether you need to register a dependency, consider the code used to create the serialized object that may not be available to the WebApp when deserializing. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_wrapped = WrappedModel(model=digpath_model)\n",
+ "model_manager = ModelManager()\n",
+ "model_name = \"zrahn_efficientnet_0219_full_parallel\"\n",
+ "model_manager.save_model(\n",
+ " model_name = model_name,\n",
+ " model = model_wrapped,\n",
+ " model_info = {\n",
+ " \"info\": \"pretrained efficientnet fine-tuned 02/19/2023 on Mild/Moderate/Severe chips (including simulated). Parallelized predictions\"\n",
+ " },\n",
+ " overwrite_model=True,\n",
+ " dependency_modules = [\n",
+ " my_model,\n",
+ " utils\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "save_dir = '/home/ec2-user/models/wrapped_models/'\n",
+ "os.rename(f'{model_manager.model_dir}/{model_name}.model', f'{save_dir}/{model_name}.model')\n",
+ "os.rename(f'{model_manager.model_dir}/{model_name}.model_info', f'{save_dir}/{model_name}.model_info')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/ml/test_e2e.py b/ml/test_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..96396c808fe09ca781fe36172b83f3c27f081079
--- /dev/null
+++ b/ml/test_e2e.py
@@ -0,0 +1,40 @@
+import boto3
+import json
+import requests
+import time
+
+def process_file(s3_file):
+ img_dir = '/home/ec2-user/data/images'
+ fname = s3_file.split('/')[-1]
+ img_file = f"{img_dir}/{fname}"
+
+ print(f'Downloading {s3_file}')
+ s3_client.download_file(image_s3_bucket_name, s3_file, img_file)
+
+ response = requests.get('http://localhost:5000/diagnose', json={"file": img_file})
+
+ request_id = response.json()['request_id']
+ print(f"Request ID: {request_id}")
+
+ status = 'in progress'
+ while status not in ('complete', 'error'):
+ response = requests.get('http://localhost:5000/request_info', json={"request_id": request_id})
+ status = response.json()['status']
+ time.sleep(20)
+
+ print(json.dumps(response.json(), indent=2))
+
+
+image_s3_bucket_name = 'digpath-data' #TODO: digpath-data-new & -batch-3
+
+s3_client = boto3.client('s3')
+s3_resource = boto3.resource('s3')
+image_bucket = s3_resource.Bucket(image_s3_bucket_name)
+
+for label in ['Moderate', 'Mild', 'Severe']:
+ for s3_obj in image_bucket.objects.filter(Prefix=f'{label}/'):
+
+ if s3_obj.key == f'{label}/':
+ continue
+
+ process_file(s3_obj.key)
diff --git a/training_image/Dockerfile b/training_image/Dockerfile
deleted file mode 100644
index 90f7bd85257507b24a6c83ad0cdcb28dfab016ab..0000000000000000000000000000000000000000
--- a/training_image/Dockerfile
+++ /dev/null
@@ -1,12 +0,0 @@
-FROM continuumio/miniconda3:4.12.0
-
-WORKDIR /opt
-
-COPY . .
-
-RUN apt update && apt install -y libglu1-mesa-dev
-RUN conda install -y -c conda-forge pyvips
-RUN pip install opencv-python
-RUN pip install -e UnifiedImageReader-main
-
-CMD bash
\ No newline at end of file
diff --git a/training_image/UnifiedImageReader-main/.devcontainer/Dockerfile b/training_image/UnifiedImageReader-main/.devcontainer/Dockerfile
deleted file mode 100644
index 8a6d51c257e6de0a94b9d1bae26ccbdd0a5c6e5d..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/.devcontainer/Dockerfile
+++ /dev/null
@@ -1,16 +0,0 @@
-# See here for image contents: https://github.com/microsoft/vscode-dev-containers/tree/v0.209.6/containers/python-3-anaconda/.devcontainer/base.Dockerfile
-
-FROM mcr.microsoft.com/vscode/devcontainers/anaconda:0-3
-
-# [Choice] Node.js version: none, lts/*, 16, 14, 12, 10
-ARG NODE_VERSION="none"
-RUN if [ "${NODE_VERSION}" != "none" ]; then su vscode -c "umask 0002 && . /usr/local/share/nvm/nvm.sh && nvm install ${NODE_VERSION} 2>&1"; fi
-
-# Copy environment.yml (if found) to a temp location so we update the environment. Also
-# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
-COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
-RUN conda env create -n vips -f /tmp/conda-tmp/environment.yml
-
-# [Optional] Uncomment this section to install additional OS packages.
-# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
-# && apt-get -y install --no-install-recommends <your-package-list-here>
diff --git a/training_image/UnifiedImageReader-main/.devcontainer/add-notice.sh b/training_image/UnifiedImageReader-main/.devcontainer/add-notice.sh
deleted file mode 100644
index d26a5df14ab41587334f110b69ceca02bad2b499..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/.devcontainer/add-notice.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-# Display a notice when not running in GitHub Codespaces
-
-cat << 'EOF' > /usr/local/etc/vscode-dev-containers/conda-notice.txt
-When using "conda" from outside of GitHub Codespaces, note the Anaconda repository
-contains restrictions on commercial use that may impact certain organizations. See
-https://aka.ms/vscode-remote/conda/anaconda
-
-EOF
-
-notice_script="$(cat << 'EOF'
-if [ -t 1 ] && [ "${IGNORE_NOTICE}" != "true" ] && [ "${TERM_PROGRAM}" = "vscode" ] && [ "${CODESPACES}" != "true" ] && [ ! -f "$HOME/.config/vscode-dev-containers/conda-notice-already-displayed" ]; then
- cat "/usr/local/etc/vscode-dev-containers/conda-notice.txt"
- mkdir -p "$HOME/.config/vscode-dev-containers"
- ((sleep 10s; touch "$HOME/.config/vscode-dev-containers/conda-notice-already-displayed") &)
-fi
-EOF
-)"
-
-echo "${notice_script}" | tee -a /etc/bash.bashrc >> /etc/zsh/zshrc
diff --git a/training_image/UnifiedImageReader-main/.devcontainer/devcontainer.json b/training_image/UnifiedImageReader-main/.devcontainer/devcontainer.json
deleted file mode 100644
index 8b7118eafa3255575bd4b7c30d4e4b31dbaa62ea..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/.devcontainer/devcontainer.json
+++ /dev/null
@@ -1,45 +0,0 @@
-// For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:
-// https://github.com/microsoft/vscode-dev-containers/tree/v0.209.6/containers/python-3-anaconda
-{
- "name": "Anaconda (Python 3)",
- "build": {
- "context": "..",
- "dockerfile": "Dockerfile",
- "args": {
- "NODE_VERSION": "none"
- }
- },
-
- // Set *default* container specific settings.json values on container create.
- "settings": {
- "python.defaultInterpreterPath": "/opt/conda/bin/python",
- "python.linting.enabled": true,
- "python.linting.pylintEnabled": true,
- "python.formatting.autopep8Path": "/opt/conda/bin/autopep8",
- "python.formatting.yapfPath": "/opt/conda/bin/yapf",
- "python.linting.flake8Path": "/opt/conda/bin/flake8",
- "python.linting.pycodestylePath": "/opt/conda/bin/pycodestyle",
- "python.linting.pydocstylePath": "/opt/conda/bin/pydocstyle",
- "python.linting.pylintPath": "/opt/conda/bin/pylint"
- },
-
- // Add the IDs of extensions you want installed when the container is created.
- "extensions": [
- "ms-python.python",
- "ms-python.vscode-pylance"
- ],
-
- // Use 'forwardPorts' to make a list of ports inside the container available locally.
- // "forwardPorts": [],
-
- // Use 'postCreateCommand' to run commands after the container is created.
- // "postCreateCommand": "python --version",
- "postAttachCommand": "autopep8 --in-place -r /workspaces/UnifiedImageReader",
-
- // Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
- "remoteUser": "vscode",
- "features": {
- "git": "os-provided",
- "sshd": "latest"
- }
-}
diff --git a/training_image/UnifiedImageReader-main/.devcontainer/noop.txt b/training_image/UnifiedImageReader-main/.devcontainer/noop.txt
deleted file mode 100644
index dde8dc3c10bc8e836a31b83efc366efd768db624..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/.devcontainer/noop.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-This file copied into the container along with environment.yml* from the parent
-folder. This file is included to prevents the Dockerfile COPY instruction from
-failing if no environment.yml is found.
\ No newline at end of file
diff --git a/training_image/UnifiedImageReader-main/.gitignore b/training_image/UnifiedImageReader-main/.gitignore
deleted file mode 100644
index 25ff4b01b4e18745fa9d6b63b7397dda4fe19023..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/.gitignore
+++ /dev/null
@@ -1,129 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-pip-wheel-metadata/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
diff --git a/training_image/UnifiedImageReader-main/README.md b/training_image/UnifiedImageReader-main/README.md
deleted file mode 100644
index 2faae55e4ed521fcbd5480b5151df7c3212e1270..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/README.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# UnifiedImageReader
-
-```mermaid
-classDiagram
- Image *-- ImageReader
- ImageReader *-- Adapter
- Adapter <|-- VIPS
- Adapter <|-- SlideIO
- class Image {
- get_region()
- number_of_regions()
- }
- class ImageReader {
- get_region()
- number_of_regions()
- validate_region()
- region_index_to_coordinates()
- }
- class Adapter {
- <<abstract>>
- get_region()
- get_width()
- get_height()
- }
-```
-
-## Installation
-
-All of the dependencies for the adapters require manual installation because of the dll dependencies. Contact Adin at adinbsolomon@gmail.com with any questions.
diff --git a/training_image/UnifiedImageReader-main/pyproject.toml b/training_image/UnifiedImageReader-main/pyproject.toml
deleted file mode 100644
index b5a3c468d9e85e7fa7469c3a90d47b48ab93e54a..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/pyproject.toml
+++ /dev/null
@@ -1,6 +0,0 @@
-[build-system]
-requires = [
- "setuptools>=42",
- "wheel"
-]
-build-backend = "setuptools.build_meta"
\ No newline at end of file
diff --git a/training_image/UnifiedImageReader-main/setup.cfg b/training_image/UnifiedImageReader-main/setup.cfg
deleted file mode 100644
index 9acc46377e3a01bc008c1efcd24c8c75675fde5f..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/setup.cfg
+++ /dev/null
@@ -1,25 +0,0 @@
-[metadata]
-name = unified-image-reader
-version = 0.0.1
-author = Adin Solomon
-author_email = adinbsolomon@gmail.com
-description = testing
-long_description = file: README.md
-long_description_content_type = text/markdown
-url = https://github.com/Digital-Pathology/UnifiedImageReader
-classifiers =
- Programming Language :: Python :: 3
- License :: OSI Approved :: MIT License
- Operating System :: OS Independent
-
-[options]
-package_dir =
- = src
-packages = find:
-install_requires =
- numpy
- pyvips
- slideio
-
-[options.packages.find]
-where = src
\ No newline at end of file
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/__init__.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/__init__.py
deleted file mode 100644
index 4ec756ecc2e18a4834a52c7a82100fa3a0c09d52..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-
-from .image import Image
-from .image_reader import ImageReader
-
-from . import util
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/__init__.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/__init__.py
deleted file mode 100644
index b03686d54f1882915fa5f2a4d25a50e03180e6ee..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-
-from .adapter import Adapter
-from .slideio import SlideIO
-from .vips import VIPS
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/adapter.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/adapter.py
deleted file mode 100644
index 6d4c046d75534ec77874857d820dff64410232f2..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/adapter.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""
- Adapter
-
- An implementation of image reading behavior that may map specific libraries to working with specific image formats
-"""
-import abc
-from typing import Iterable
-
-import numpy as np
-
-
-class Adapter(abc.ABC):
-
- @abc.abstractmethod
- def get_region(self, region_coordinates: Iterable, region_dims: Iterable) -> np.ndarray:
- """get_region Get a pixel region of the image using the adapter library's implementation
-
- :param region_coordinates: A set of (width, height) coordinates representing the top-left pixel of the region
- :type region_coordinates: Iterable
- :param region_dims: A set of (width, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :return: A numpy array representative of the pixel region from the image
- :rtype: np.ndarray
- """
- pass
-
- @abc.abstractmethod
- def get_width() -> int:
- """
- Get the width property of the image using the adapter library's implementation
- """
- pass
-
- @abc.abstractmethod
- def get_height() -> int:
- """
- Get the height property of the image using the adapter library's implementation
- """
- pass
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/config.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/config.py
deleted file mode 100644
index 9909acc47d6666c4be2990fdcc2a1c04b2ec0024..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/config.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-VIPS_GET_REGION = "IMAGE_CROP" # alternatively "REGION_FETCH"
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/slideio.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/slideio.py
deleted file mode 100644
index 70c4ca089d696b9b41c748e85b5948145e90bca9..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/slideio.py
+++ /dev/null
@@ -1,57 +0,0 @@
-"""
- SlideIO Adapter
-
- An adapter that uses the SlideIO library to implement image reading behavior
- Adapter currently mapped to reading .svs files
-"""
-import numpy as np
-
-try:
- import slideio
-except Exception as e:
- print("You have an issue with your SlideIO installation, it may be because of the dependency on Openslide. Contact Adin at adinbsolomon@gmail.com with any questions!")
- raise e
-
-from .adapter import Adapter
-
-
-class SlideIO(Adapter):
-
- def __init__(self, filepath):
- """__init__ Initialize SlideIO adapter object
-
- :param filepath: Filepath to image file to be opened
- :type filepath: str
- """
-
- self._image = slideio.open_slide(filepath, "SVS").get_scene(0)
-
- def get_width(self):
- """get_width Get the width property of the image using SlideIO's implementation
-
- :return: Height in pixels
- :rtype: int
- """
- return self._image.size[0]
-
- def get_height(self):
- """get_height Get the height property of the image using SlideIO's implementation
-
- :return: Width in pixels
- :rtype: int
- """
- return self._image.size[1]
-
- def get_region(self, region_coordinates, region_dims) -> np.ndarray:
- """get_region Get a pixel region of the image using SlideIO's implementation
-
- :param region_coordinates: A set of (width, height) coordinates representing the top-left pixel of the region
- :type region_coordinates: Iterable
- :param region_dims: A set of (width, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :return: A numpy array representative of the pixel region from the image
- :rtype: np.ndarray
- """
- """ Calls the read_block method of a SlideIO Scene object to create an unscaled rectangular region of the image as a numpy array """
- np_array = self._image.read_block((*region_coordinates, *region_dims))
- return np_array
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/vips.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/vips.py
deleted file mode 100644
index e720503631184fd58479687a1b4a3e98c82d151f..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/adapters/vips.py
+++ /dev/null
@@ -1,86 +0,0 @@
-"""
- VIPS Adapter
-
- An adapter that uses pyvips, the Python extension of the libvips library, to implement image reading behavior
- Adapter currently mapped to reading .tif, tiff files
-"""
-
-import numpy as np
-
-try:
- import pyvips
-except Exception as e:
- print("You have an issue with your pyvips installation, it may be because of the dependency on libvips. Contact Adin at adinbsolomon@gmail.com with any questions!")
- raise e
-
-from .adapter import Adapter
-from . import config
-
-FORMAT_TO_DTYPE = {
- 'uchar': np.uint8,
- 'char': np.int8,
- 'ushort': np.uint16,
- 'short': np.int16,
- 'uint': np.uint32,
- 'int': np.int32,
- 'float': np.float32,
- 'double': np.float64,
- 'complex': np.complex64,
- 'dpcomplex': np.complex128
-}
-
-
-class VIPS(Adapter):
-
- def __init__(self, filepath: str):
- """__init__ Initialize VIPS adapter object
-
- :param filepath: Filepath to image file to be opened
- :type filepath: str
- """
- self._image = pyvips.Image.new_from_file(filepath, access="random")
-
- def get_width(self) -> int:
- """get_height Get the height property of the image using VIPS' implementation
-
- :return: Height in pixels
- :rtype: int
- """
- return self._image.width
-
- def get_height(self) -> int:
- """get_height Get the height property of the image using VIPS' implementation
-
- :return: Height in pixels
- :rtype: int
- """
- return self._image.height
-
- def get_region(self, region_coordinates, region_dims) -> np.ndarray:
- """get_region Get a pixel region of the image using VIPS' implementation
-
- :param region_coordinates: A set of (width, height) coordinates representing the top-left pixel of the region
- :type region_coordinates: Iterable
- :param region_dims: A set of (width, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :return: A numpy array representative of the pixel region from the image
- :rtype: np.ndarray
- """
- if config.VIPS_GET_REGION == "IMAGE_CROP":
- output_img = self._image.crop(*region_coordinates, *region_dims)
- np_output = np.ndarray(
- buffer=output_img.write_to_memory(),
- dtype=FORMAT_TO_DTYPE[output_img.format],
- shape=[output_img.height, output_img.width, output_img.bands]
- )
- return np_output
- elif config.VIPS_GET_REGION == "REGION_FETCH":
- vips_region = pyvips.Region.new(self._image)
- bytestring_buffer = vips_region.fetch(
- *region_coordinates, *region_dims)
- np_output = np.frombuffer(bytestring_buffer, dtype=np.uint8)
- region = np_output.reshape(*region_dims, 3)
- return region
- else:
- raise Exception(
- f"Invalid vips get region mode {config.VIPS_GET_REGION=}")
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/config.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/config.py
deleted file mode 100644
index 7b7f83508f9166ce466518ae24c943cca1fafe6c..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/config.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-DEFAULT_REGION_DIMS = (512, 512)
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/image.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/image.py
deleted file mode 100644
index 6d05e7f10a592619fa863ed7bfe9810b036353b0..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/image.py
+++ /dev/null
@@ -1,123 +0,0 @@
-
-"""
- An interface into optimized image reading behavior with optional overriding.
-"""
-
-import contextlib
-from typing import Optional
-
-import numpy as np
-
-from . import config
-from . import image_reader
-
-
-class Image(contextlib.AbstractContextManager):
-
- """
- Image An image to be streamed into a specialized reader
- """
-
- def __init__(self, filepath, reader=None):
- """__init__ Initialize Image object
-
- :param filepath: Filepath to image file to be opened
- :type filepath: str
- :param reader: Interface to reading the image file, defaults to None
- :type reader: ImageReader or custom class supportive of the same functions, optional
- """
- self.filepath = filepath
- self.reader = reader or image_reader.ImageReader(filepath)
- self._iter = None
-
- def get_region(self, region_identifier, region_dims=config.DEFAULT_REGION_DIMS) -> np.ndarray:
- """
- get_region Get a pixel region from the image
-
- :param region_identifier: A set of (width, height) coordinates or an indexed region based on region dimensions
- :type region_identifier: Union[int, Iterable]
- :param region_dims: A set of (width, height) coordinates representing the region dimensions, defaults to DEFAULT_REGION_DIMS
- :type region_dims: Iterable, optional
- :return: A numpy array representative of the pixel region from the image
- :rtype: np.ndarray
- """
- return self.reader.get_region(region_identifier, region_dims)
-
- def number_of_regions(self, region_dims=config.DEFAULT_REGION_DIMS) -> int:
- """
- number_of_regions Get total number of regions from the image based on region dimensions
-
- :param region_dims: A set of (width, height) coordinates representing the region dimensions, defaults to DEFAULT_REGION_DIMS
- :type region_dims: Iterable, optional
- :return: Number of regions in the image
- :rtype: int
- """
- return self.reader.number_of_regions(region_dims)
-
- @property
- def width(self):
- """
- width Get the width property of the image using its reader
-
- :return: Width in pixels
- :rtype: int
- """
- return self.reader.width
-
- @property
- def height(self):
- """
- height Get the height property of the image using its reader
-
- :return: Height in pixels
- :rtype: int
- """
- return self.reader.height
-
- @property
- def dims(self):
- """
- dims Get the width and height properties of the image
-
- :return: Width and height in pixels
- :rtype: Tuple[int]
- """
- return self.width, self.height
-
- def __iter__(self):
- """
- __iter__ Initialize Image object iterator
-
- :raises Exception: Iterator already initialized but is called again
- :return: Iterator for Image object
- :rtype: Image
- """
- self._iter = 0
- return self
-
- def __next__(self):
- """
- __next__ Get the next pixel region index in a sequence of iterating through an Image object
-
- :raises StopIteration: Iterator has reached the last region in the image
- :return: Next pixel region index
- :rtype: int
- """
- if self._iter >= self.number_of_regions():
- raise StopIteration
- else:
- region = self.get_region(self._iter)
- self._iter += 1
- return region
-
- def __len__(self):
- """
- __len__ Get the number of pixel regions in an iterable sequence of an Image object
-
- :return: The number of pixel regions in the Image object
- :rtype: int
- """
- return self.number_of_regions()
-
- def __exit__(self, **kwargs) -> Optional[bool]:
- return super().__exit__(**kwargs)
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/image_reader.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/image_reader.py
deleted file mode 100644
index 2dde19c85dd09438f6bd0bf43358c72d01c7b677..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/image_reader.py
+++ /dev/null
@@ -1,285 +0,0 @@
-"""
- An ImageReader controls the behavior of the image interface. It can either utilize an adapter on a library or custom behavior.
-"""
-
-import os
-from typing import Any, Iterable, Optional, Tuple, Union
-
-import cv2 as cv
-import numpy as np
-
-from unified_image_reader.adapters import Adapter, SlideIO, VIPS
-
-FORMAT_ADAPTER_MAP = {
- "tif": VIPS,
- "tiff": VIPS,
- "svs": SlideIO
-}
-
-
-class UnsupportedFormatException(Exception):
- pass
-
-
-class InvalidCoordinatesException(Exception):
- pass
-
-
-class InvalidDimensionsException(Exception):
- pass
-
-
-class ImageReader():
- """
- ImageReader Interface between images and adapters which specify reading behavior
-
- :raises UnsupportedFormatException: The adapter does not support the image format
- :raises TypeError: Enforces provided type hinting for region_identifier arguments
- :raises IndexError: The top-left pixel or pixel region dimensions are out of bounds of the image dimensions
- :raises InvalidCoordinatesException: The top-left pixel rwas not provided in (width, height) format
- :raises InvalidDimensionsException: Dimensions of the pixel region were not provided in (width, height) format
- """
-
- def __init__(self, filepath: str, adapter: Union[Adapter, None] = None):
- """
- __init__ Initialize ImageReader object
-
- :param filepath: Filepath to image file to be opened
- :type filepath: str
- :param adapter: Object which specifies reading behavior, defaults to None
- :type adapter: Union[Adapter, None], optional
- :raises UnsupportedFormatException: The adapter does not support the image format
- """
- # process filepath
- assert os.path.isfile(
- filepath), f"filepath is not a file --> {filepath}"
-
- self.filepath = filepath
- # initialize the adapter
- self.adapter = None
- if adapter is None: # choose based on file format
- image_format = self.filepath.split('.')[-1]
- adapter = FORMAT_ADAPTER_MAP.get(image_format)
- if adapter is None:
- raise UnsupportedFormatException(image_format)
- self.adapter = adapter(filepath)
-
- def get_region(self, region_identifier: Union[int, Iterable], region_dims: Iterable):
- """
- get_region Get a pixel region from an image using an adapter's implementation after validation and extracting region data
-
- :raises TypeError: The starting pixels or pixel regions are out of bounds of the image dimensions
- :param region_identifier: A set of (width, height) coordinates or an indexed region based on region dimensions
- :type region_identifier: Union[int, Iterable]
- :param region_dims: A set of (weight, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :return: A numpy array representative of the pixel region from the image
- :rtype: np.ndarray
- """
- # Make sure that region_coordinates is a tuple of length 2
- region_coordinates = None
- if isinstance(region_identifier, int):
- region_coordinates = self.region_index_to_coordinates(
- region_identifier, region_dims)
- elif isinstance(region_identifier, Iterable):
- assert (len(region_identifier) == 2)
- region_coordinates = region_identifier
- else:
- raise TypeError(
- f"region_identifier should be either int or Iterable but is {type(region_identifier)=}, {region_identifier=}")
- # make sure that the region is in bounds
- self.validate_region(region_coordinates, region_dims)
- # call the implementation
- return self._get_region(region_coordinates, region_dims)
-
- def _get_region(self, region_coordinates, region_dims) -> np.ndarray:
- """
- _get_region Call an adapter's implementation to get a pixel region from an image
-
- :param region_coordinates: A set of (width, height) coordinates representing the top-left pixel of the region
- :type region_coordinates: Iterable
- :param region_dims: A set of (width, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :return: Implementation resulting in a numpy array representative of the pixel region from the image
- :rtype: np.ndarray
- """
-
- return self.adapter.get_region(region_coordinates, region_dims)
-
- def number_of_regions(self, region_dims: Iterable):
- """
- number_of_regions Calculates the number of regions in the image based on the dimensions of each region
-
- :param region_dims: A set of (width, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :return: The number of regions
- :rtype: int
- """
-
- width, height = region_dims
- return (self.width // width) * (self.height // height)
-
- def validate_region(self, region_coordinates: Iterable, region_dims: Iterable) -> None:
- """
- validate_region Checks that a region is within the bounds of the image
-
- :param region_coordinates: A set of (width, height) coordinates representing the top-left pixel of the region
- :type region_coordinates: Iterable
- :param region_dims: A set of (width, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :raises IndexError: The top-left pixel or pixel region dimensions are out of the bounds of the image dimensions
- :raises InvalidCoordinatesException: The top-left pixel was not presented in (width, height) format
- :raises InvalidDimensionsException: Dimensions of the pixel region were not presented in (width, height) format
- """
-
- def not_valid():
- """
- not_valid Wrapper function to raise an error on invalid coordinates or dimensions
-
- :raises IndexError: The top-left pixel or pixel region dimensions are out of the bounds of the image dimensions
- """
-
- raise IndexError(region_coordinates, region_dims, self.dims)
- # first ensure coordinates are in bounds
- if not (len(region_coordinates) == 2):
- raise InvalidCoordinatesException(region_coordinates)
- left, top = region_coordinates
- if not (0 <= left < self.width):
- not_valid()
- if not (0 <= top < self.height):
- not_valid()
- # then check dimensions with coordinates
- if not (len(region_dims) == 2):
- raise InvalidDimensionsException(region_dims)
- region_width, region_height = region_dims
- if not (0 < region_width and left+region_width <= self.width):
- not_valid()
- if not (0 < region_height and top+region_height <= self.height):
- not_valid()
-
- def region_index_to_coordinates(self, region_index: int, region_dims: Iterable):
- """
- region_index_to_coordinates Converts the index of a region to coordinates of the top-left pixel of the region
-
- :param region_index: The nth region of the image (where n >= 0) based on region dimensions
- :type region_index: int
- :param region_dims: A set of (width, height) coordinates representing the region dimensions
- :type region_dims: Iterable
- :return: A set of (width, height) coordinates representing the top-left pixel of the region
- :rtype: Iterable
- """
-
- region_width, region_height = region_dims
- width_regions = self.width // region_width
- left = (region_index % width_regions) * region_width
- top = (region_index // width_regions) * region_height
- return (left, top)
-
- @property
- def width(self):
- """
- width Get the width property of the image using the adapter's implementation
-
- :return: Width in pixels
- :rtype: int
- """
- return self.adapter.get_width()
-
- @property
- def height(self):
- """
- height Get the height property of the image using the adapter's implementation
-
- :return: Height in pixels
- :rtype: int
- """
- return self.adapter.get_height()
-
- @property
- def dims(self):
- """
- dims Get the width and height property of the image using the adapter's implementation
-
- :return: Width and height in pixels
- :rtype: Iterable
- """
- return self.width, self.height
-
-
-class ImageReaderDirectory(ImageReader):
-
- """
- Treats a collection of images as a single image whereby regions don't have locations, only alphabetically-organized indices.
- This works with both a directory and a list of image files.
- """
-
- def __init__(self, data: Union[str, list, tuple]):
- """
- __init__
-
- :param data: the location(s) of constituent images
- :type data: str
- :raises Exception: when data is a string but isn't a directory
- :raises TypeError: when data is neither a string nor list/tuple
- :raises Exception: when a file in data (when data is a list) doesn't exist as a file
- """
- if isinstance(data, str):
- self._dir = data
- if not os.path.isdir(self._dir):
- raise Exception(f"{data=} should be a path to a directory")
- self._region_files = [os.path.join(
- self._dir, p) for p in os.listdir(self._dir)]
- elif isinstance(data, [list, tuple]):
- self._region_files = data
- for region_filepath in self._region_files:
- if not os.path.isfile(region_filepath):
- raise Exception(
- f"self._region_files should be composed of filepaths to existing image files but includes {region_filepath}")
- else:
- raise TypeError(f"Didn't expect {type(data)=}, {data=}")
- self._region_files.sort()
-
- def get_region(self, region_identifier: int, region_dims: Optional[Any] = None) -> np.ndarray:
- """
- get_region reads in the image at self._region_files[region_identifier]
-
- :param region_identifier: the index of the file read (files are indexed alphabetically)
- :type region_identifier: int
- :param region_dims: IGNORED - the regions will be whatever the regions of the image file are, defaults to None
- :type region_dims: Any, optional
- :raises NotImplementedError: if the region identifier isn't an index
- :raises IndexError: if region_identifier isn't in range
- :return: region (the image in the file in question)
- :rtype: np.ndarray
- """
- if not isinstance(region_identifier, int):
- raise NotImplementedError(
- "This ImageReader only operates on aggregated region files which are indexed alphabetically. Region coordinates are not supported.")
- if not (0 <= region_identifier < self.number_of_regions()):
- raise IndexError(
- f"{region_identifier=}, {self.number_of_regions()=}")
- region_filepath = self._region_files[region_identifier]
- return cv.imread(region_filepath)
-
- def number_of_regions(self, region_dims: Optional[Any] = None) -> int:
- """
- number_of_regions the number of region in the image (in this case, the number of image files)
-
- :param region_dims: IGNORED - the dimensions of the regions in this image are the dimensions of the image in the files, defaults to None
- :type region_dims: Optional[Any], optional
- :return: the number of regions in this image (the number of the image files)
- :rtype: int
- """
- return len(self._region_files)
-
- @property
- def width(self):
- raise NotImplementedError()
-
- @property
- def height(self):
- raise NotImplementedError()
-
- @property
- def dims(self):
- raise NotImplementedError()
diff --git a/training_image/UnifiedImageReader-main/src/unified_image_reader/util.py b/training_image/UnifiedImageReader-main/src/unified_image_reader/util.py
deleted file mode 100644
index 3b40650b15e8183f08492896cc45a000bc49d4c3..0000000000000000000000000000000000000000
--- a/training_image/UnifiedImageReader-main/src/unified_image_reader/util.py
+++ /dev/null
@@ -1,35 +0,0 @@
-
-"""
- Utility functions and classes for the Unified Image Reader
-"""
-
-import os
-from typing import List, NewType, Tuple, Union
-
-RegionDimensions = NewType('RegionDimensions', Tuple[int, int])
-
-RegionIndex = NewType('RegionIndex', int)
-RegionCoordinates = NewType('RegionCoordinates', Tuple[int, int])
-RegionIdentifier = NewType(
- 'RegionIdentifier', Union[RegionIndex, RegionCoordinates])
-
-FilePath = NewType('FilePath', str)
-
-
-def listdir_recursive(path: FilePath) -> List[FilePath]:
- """
- listdir_recursive lists files (not directories) recursively from path
-
- :param path: the path to the directory whose files should be listed recursively
- :type path: FilePath
- :return: a list of filepaths relative to path
- :rtype: List[FilePath]
- """
- files = []
- walk = os.walk(path)
- for (directory_pointer, _, file_nodes) in walk:
- files += [
- os.path.join(directory_pointer, file_node)
- for file_node in file_nodes
- ]
- return files
diff --git a/training_image/UnifiedImageReader-main/tests/images/test-image.tiff b/training_image/UnifiedImageReader-main/tests/images/test-image.tiff
deleted file mode 100644
index 4e71ab8e800c9b6eaf30ba43ea3a70c6d49f63d5..0000000000000000000000000000000000000000
Binary files a/training_image/UnifiedImageReader-main/tests/images/test-image.tiff and /dev/null differ