-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
102 lines (80 loc) · 3.05 KB
/
predict.py
File metadata and controls
102 lines (80 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python
# coding: utf-8
import argparse
import sqlite3
import time
from glob import glob
from pathlib import Path
from loguru import logger
from tqdm import tqdm
from ultralytics import YOLO
def opts() -> argparse.Namespace:
    """Parse the command-line options of the prediction script.

    Three string options are defined, all of them required:
    ``-d/--db-path``, ``-m/--model-path`` and ``-i/--images-dir``.

    Returns:
        The parsed namespace with ``db_path``, ``model_path`` and
        ``images_dir`` attributes.
    """
    cli = argparse.ArgumentParser()

    # (short flag, long flag, help text) for every required path option.
    path_options = (
        ('-d', '--db-path', 'Path to the sqlite database file'),
        ('-m', '--model-path', 'Path to the known_unknown model file'),
        ('-i', '--images-dir', 'Path to the images directory'),
    )
    for short_flag, long_flag, description in path_options:
        cli.add_argument(short_flag,
                         long_flag,
                         help=description,
                         type=str,
                         required=True)

    return cli.parse_args()
def run(db_path: str, model_path: str, images_dir: str) -> None:
    """Classify the not-yet-processed images of a directory into SQLite.

    The function creates (if missing) the ``ainatype_detections`` and
    ``ainatype_errored`` tables, loads the YOLO model, and runs it on every
    regular file directly inside ``images_dir`` whose file name is not yet
    in the detections table. Each result is stored as the file name followed
    by the model's class probabilities rounded to two decimals. Images that
    raise during inference or insertion are logged and their file name is
    recorded once in the errored table.

    Work is committed every 1000 successful images and once more when the
    function exits, including when it exits through an exception or an
    interrupt, so finished results are not lost. The connection is always
    closed.

    Args:
        db_path: Path to the SQLite database file.
        model_path: Path to the model weights passed to ``YOLO``.
        images_dir: Directory whose direct children are the images.
    """
    table_prefix = 'ainatype'
    batch_size = 1000

    conn = sqlite3.connect(db_path, timeout=30)
    try:
        cursor = conn.cursor()
        cursor.execute("PRAGMA journal_mode=WAL;")
        conn.commit()

        # NOTE(review): the five *_conf columns are filled positionally from
        # the model output, so this presumably expects a five-class model
        # whose class order matches the column order -- confirm.
        cursor.execute(f'''CREATE TABLE IF NOT EXISTS {table_prefix}_detections
                    (image TEXT PRIMARY KEY,
                    dead_conf REAL,
                    live_animal_conf REAL,
                    other_conf REAL,
                    scat_conf REAL,
                    tracks_conf REAL)''')
        cursor.execute(f'''CREATE TABLE IF NOT EXISTS {table_prefix}_errored
                    (image TEXT PRIMARY KEY)''')

        model = YOLO(model_path)

        # Keep regular files only: a sub-directory handed to the model would
        # be treated as a whole image source and stored under its own name.
        imgs = [x for x in glob(f'{images_dir}/*') if Path(x).is_file()]
        len_1 = len(imgs)

        existing_images = {
            row[0] for row in cursor.execute(
                f'SELECT image FROM {table_prefix}_detections')
        }
        existing_errored_images = {
            row[0] for row in cursor.execute(
                f'SELECT image FROM {table_prefix}_errored')
        }

        imgs = [x for x in imgs if Path(x).name not in existing_images]
        len_2 = len(imgs)
        logger.info(f'Excluded {len_1 - len_2} existing images.')
        time.sleep(2)

        batch_count = 0
        for img in tqdm(imgs):
            name = Path(img).name
            try:
                res = model(img)[0]
                probs = [name] + [round(x, 2)
                                  for x in res.probs.data.tolist()]
                cursor.execute(
                    f'INSERT INTO {table_prefix}_detections '
                    'VALUES (?, ?, ?, ?, ?, ?)', probs)
            except Exception as e:
                # Best effort: one bad image must not stop the whole run.
                logger.error(e)
                if name not in existing_errored_images:
                    cursor.execute(
                        f'INSERT INTO {table_prefix}_errored VALUES (?)',
                        (name, ))
                    existing_errored_images.add(name)
                continue

            batch_count += 1
            if batch_count % batch_size == 0:
                # Handled outside the per-image ``try`` so that a failed
                # commit is not mistaken for a failed image; the pending
                # rows stay in the transaction for the next commit.
                try:
                    conn.commit()  # Commit every 1000 iterations
                except sqlite3.Error as e:
                    logger.error(e)
                else:
                    logger.info(
                        f'Committed batch {batch_count // batch_size}')
    finally:
        # Persist whatever was processed, even on an error or Ctrl-C, and
        # release the database handle in every case.
        try:
            conn.commit()
        finally:
            conn.close()
def main() -> None:
    """Script entry point: parse the CLI options and run the prediction."""
    cli_args = opts()
    run(db_path=cli_args.db_path,
        model_path=cli_args.model_path,
        images_dir=cli_args.images_dir)


# Only execute when launched as a script, not when imported as a module.
if __name__ == '__main__':
    main()