-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtaskqueue.py
More file actions
88 lines (81 loc) · 3.56 KB
/
Copy pathtaskqueue.py
File metadata and controls
88 lines (81 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from constants import SQLITE_DB
from db import write_to_db, read_from_db
from embeddings import pdf_to_embeddings
import threading
import time
from extra.logger_config import setup_logger
# Module-level logger for this file, created by the project's setup_logger helper.
logger = setup_logger(__name__)
class FileQueue:
    '''Database-backed queue of uploaded files waiting to be embedded.

    Rows live in the ``data_sources`` table. A row is claimed by flipping its
    status from 'pending' to 'processing'; ``processing_cnt`` counts the rows
    claimed by this instance that have not been finished yet and is capped
    by ``max_size``.
    '''

    def __init__(self, db_path=SQLITE_DB, max_size=30):
        '''
        db_path: path to the sqlite3 database (stored on the instance only;
            the db helpers used by this class are called without it)
        max_size: maximum number of claimed-but-unfinished files allowed
            before the worker pauses claiming new ones
        '''
        self.db_path = db_path
        self.max_size = max_size
        self.processing_cnt = 0
        # Guards the db helper calls made here and every update of processing_cnt.
        self.lock = threading.Lock()

    def start_processing(self):
        '''Start processing files in queue on a background thread
        '''
        # NOTE(review): the thread is non-daemon and process_files never
        # returns, so it keeps the interpreter alive - confirm this is intended.
        t = threading.Thread(target=self.process_files)
        t.start()

    def add(self, user, fname, dtype, source_id=None):
        '''Add file to db and return source_id

        user: id of the owning user (stored in the user_id column)
        fname: file name (stored in the name column)
        dtype: value stored in the dtype column
        source_id: id to use for the row; an 8-character one is generated
            when it is missing or empty
        '''
        # Imported lazily, as in the original code.
        from extra.utils import generate_uuid
        if not source_id:
            source_id = generate_uuid(length=8)
        add_file_q = 'INSERT INTO data_sources (source_id,user_id,name,dtype) VALUES (?,?,?,?)'
        file_entry = [source_id, user, fname, dtype]
        with self.lock:
            write_to_db(add_file_q, file_entry)
        return source_id

    def get_file(self):
        '''Claim the oldest pending file

        Returns the row (source_id, user_id, name, status, dtype) of the
        oldest 'pending' entry after marking it 'processing' in the db, or
        None when nothing is pending. The returned row was read before the
        update, so its own status field still holds the old value.
        '''
        # SQL string literals are single-quoted: SQLite only accepts
        # double-quoted strings as a compatibility quirk, and would resolve
        # "pending" to a column of that name if one existed.
        read_q = (
            'SELECT source_id, user_id, name, status, dtype FROM data_sources '
            "WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1"
        )
        update_q = "UPDATE data_sources SET status = 'processing' WHERE source_id = ?"
        # Select, update and counter bump happen under one lock acquisition
        # so two threads cannot claim the same row.
        with self.lock:
            rows = read_from_db(read_q)
            result = rows[0] if rows else None
            if not result:
                return None
            write_to_db(update_q, [result['source_id']])
            self.processing_cnt += 1
            return result

    def process_files(self):
        '''Process files in queue

        Runs forever: claims one pending file at a time, embeds it, and
        records the outcome. Sleeps 5 seconds between polls and 10 seconds
        while the number of claimed files is at max_size.
        '''
        while True:
            if self.processing_cnt >= self.max_size:
                logger.info('Queue is full, sleeping for 10 seconds ...')
                time.sleep(10)
                continue
            result = self.get_file()
            if result:
                self._process_one(result)
            time.sleep(5)

    def _process_one(self, result):
        '''Embed one claimed row and always give its slot back.'''
        fname = result['name']
        user_id = result['user_id']
        source_id = result['source_id']
        if not fname.lower().endswith('.pdf'):
            # Only PDFs are handled. The slot must still be released, or
            # max_size unsupported files would stall the worker for good.
            # NOTE(review): the row stays 'processing' in the db; decide
            # which status unsupported files should end up with.
            logger.warning(f'Skipping unsupported file {fname} ({source_id}) for {user_id}')
            self._release_slot()
            return
        try:
            processing_output = pdf_to_embeddings(fname=fname, user_id=user_id, source_id=source_id)
            status = processing_output['status']
            n_tokens = processing_output['n_tokens']
        except Exception as e:
            # Keep the worker thread alive and free the slot when a single
            # file fails instead of letting the exception end the loop.
            logger.error(f'Error processing {source_id} for {user_id}: {e}')
            self._release_slot()
            return
        self.mark_as_processed(source_id, user_id, status, n_tokens)

    def _release_slot(self):
        '''Decrement processing_cnt under the lock.'''
        with self.lock:
            self.processing_cnt -= 1

    def mark_as_processed(self, source_id, user_id, status, n_tokens):
        '''Mark file as processed

        Stores status and n_tokens on the data_sources row, adds the file
        and its tokens to the user's row in the usage table, and frees the
        processing slot. Database errors are logged, not raised.
        '''
        update_q = 'UPDATE data_sources SET status = ?, n_tokens = ? WHERE source_id = ? AND user_id = ?'
        sql_upsert_sources = 'INSERT INTO usage (user_id, n_chatbots, n_sources, n_tokens, n_messages) VALUES (?, 0, 1, ?, 0) ON CONFLICT(user_id) DO UPDATE SET n_sources = n_sources + 1, n_tokens = n_tokens + ?'
        with self.lock:
            try:
                # int() stays inside the try so a bad n_tokens is logged too.
                write_to_db(update_q, [status, int(n_tokens), source_id, user_id])
                write_to_db(sql_upsert_sources, [user_id, int(n_tokens), int(n_tokens)])
            except Exception as e:
                logger.error(f'Error updating {source_id} for {user_id}: {e}')
            finally:
                # Decrement directly: the lock is already held and
                # threading.Lock is not reentrant.
                self.processing_cnt -= 1