Skip to content

Commit cfcb528

Browse files
committed
fix: per-profile file locking to prevent concurrent duplicate checkouts
1 parent 6c05cff commit cfcb528

3 files changed

Lines changed: 144 additions & 27 deletions

File tree

src/pybritive/britive_cli.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from . import __version__
2525
from .helpers import cloud_credential_printer as printer
2626
from .helpers.cache import Cache
27+
from .helpers.checkout_lock import CheckoutLock
2728
from .helpers.config import ConfigManager
2829
from .helpers.credentials import EncryptedFileCredentialManager, FileCredentialManager
2930
from .helpers.split import profile_split
@@ -792,6 +793,17 @@ def _checkout(
792793
}
793794
raise e
794795

796+
def _check_cache(self, passphrase: Optional[str], profile_name: str, mode: str) -> Optional[dict]:
797+
credentials = Cache(passphrase=passphrase).get_credentials(profile_name=profile_name, mode=mode)
798+
if credentials:
799+
expiration_timestamp_str = jmespath.search(
800+
expression=self.cachable_modes[mode]['expiration_jmespath'], data=credentials
801+
).replace('Z', '')
802+
expires = datetime.fromisoformat(expiration_timestamp_str)
803+
if datetime.utcnow() < expires:
804+
return credentials
805+
return None
806+
795807
@staticmethod
796808
def _should_check_force_renew(app, force_renew, console):
797809
return app in ['AWS', 'AWS Standalone'] and force_renew and not console
@@ -903,25 +915,14 @@ def _access_checkout(
903915
self._validate_justification(justification)
904916

905917
if mode in self.cachable_modes:
906-
self.silent = True # CANNOT output anything other than the expected JSON
907-
# we need to check the cache for the credentials first and then check to see if they are expired
908-
# if not simply return those credentials, if they are expired, continue to do an actual checkout
918+
self.silent = True
909919
app_type = self.cachable_modes[mode]['app_type']
910-
credentials = Cache(passphrase=passphrase).get_credentials(profile_name=alias or profile, mode=mode)
920+
credentials = self._check_cache(passphrase, alias or profile, mode)
911921
if credentials:
912-
expiration_timestamp_str = jmespath.search(
913-
expression=self.cachable_modes[mode]['expiration_jmespath'], data=credentials
914-
).replace('Z', '')
915-
expires = datetime.fromisoformat(expiration_timestamp_str)
916-
now = datetime.utcnow()
917-
if now >= expires: # check to ensure the credentials are still valid, if not, set to None and get new
918-
credentials = None
919-
else:
920-
cached_credentials_found = True
922+
cached_credentials_found = True
921923

922924
parts = self._split_profile_into_parts(profile)
923925

924-
# create this params once so we can use it multiple places
925926
params = {
926927
'app_name': parts['app'],
927928
'blocktime': blocktime,
@@ -936,30 +937,45 @@ def _access_checkout(
936937
'ticket_type': ticket_type,
937938
}
938939

939-
if not cached_credentials_found: # nothing found in cache, cache is expired, or not a cachable mode
940-
response = self._checkout(**params)
941-
app_type = self._get_app_type(response['appContainerId'])
942-
credentials = response['credentials']
943-
console_fallback = response.get('console-fallback')
940+
if not cached_credentials_found:
941+
if mode in self.cachable_modes and self.config.checkout_lock_enabled():
942+
with CheckoutLock(profile_key=alias or profile, mode=mode):
943+
credentials = self._check_cache(passphrase, alias or profile, mode)
944+
if credentials:
945+
cached_credentials_found = True
946+
else:
947+
response = self._checkout(**params)
948+
app_type = self._get_app_type(response['appContainerId'])
949+
credentials = response['credentials']
950+
console_fallback = response.get('console-fallback')
951+
Cache(passphrase=passphrase).save_credentials(
952+
profile_name=alias or profile, credentials=credentials, mode=mode
953+
)
954+
else:
955+
response = self._checkout(**params)
956+
app_type = self._get_app_type(response['appContainerId'])
957+
credentials = response['credentials']
958+
console_fallback = response.get('console-fallback')
959+
if mode in self.cachable_modes:
960+
Cache(passphrase=passphrase).save_credentials(
961+
profile_name=alias or profile, credentials=credentials, mode=mode
962+
)
944963

945-
# this handles the --force-renew flag
946-
# lets check to see if we should checkin this profile first and check it out again
947964
if self._should_check_force_renew(app_type, force_renew, console):
948965
expiration = datetime.fromisoformat(credentials['expirationTime'].replace('Z', ''))
949966
now = datetime.utcnow()
950967
diff = (expiration - now).total_seconds() / 60.0
951-
if diff < force_renew: # time to checkin the profile so we can refresh creds
968+
if diff < force_renew:
952969
self.print('checking in the profile to get renewed credentials....standby')
953970
self.checkin(profile=profile, console=console)
954971
response = self._checkout(**params)
955-
cached_credentials_found = False # need to write new creds to cache
956972
credentials = response['credentials']
957973
console_fallback = response.get('console-fallback')
974+
if mode in self.cachable_modes:
975+
Cache(passphrase=passphrase).save_credentials(
976+
profile_name=alias or profile, credentials=credentials, mode=mode
977+
)
958978

959-
if mode in self.cachable_modes and not cached_credentials_found:
960-
Cache(passphrase=passphrase).save_credentials(
961-
profile_name=alias or profile, credentials=credentials, mode=mode
962-
)
963979
return app_type, console_fallback, credentials, k8s_processor
964980

965981
def checkout(
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import hashlib
2+
import os
3+
import time
4+
from pathlib import Path
5+
from types import TracebackType
6+
from typing import Optional, Type
7+
8+
9+
class CheckoutLockTimeout(Exception):
10+
pass
11+
12+
13+
class _WouldBlock(Exception):
14+
pass
15+
16+
17+
try:
18+
import fcntl
19+
20+
def _lock_fd(fd: int) -> None:
21+
try:
22+
fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
23+
except (OSError, IOError):
24+
raise _WouldBlock()
25+
26+
def _unlock_fd(fd: int) -> None:
27+
fcntl.flock(fd, fcntl.LOCK_UN)
28+
29+
except ImportError:
30+
import msvcrt
31+
32+
def _lock_fd(fd: int) -> None:
33+
try:
34+
msvcrt.locking(fd, msvcrt.LK_NBLCK, 1)
35+
except (OSError, IOError):
36+
raise _WouldBlock()
37+
38+
def _unlock_fd(fd: int) -> None:
39+
try:
40+
msvcrt.locking(fd, msvcrt.LK_UNLCK, 1)
41+
except (OSError, IOError):
42+
pass
43+
44+
45+
class CheckoutLock:
46+
def __init__(self, profile_key: str, mode: str, timeout: float = 120.0, poll_interval: float = 0.1) -> None:
47+
self.timeout: float = timeout
48+
self.poll_interval: float = poll_interval
49+
self._fd: Optional[int] = None
50+
51+
home = os.getenv('PYBRITIVE_HOME_DIR', str(Path.home()))
52+
lock_dir = Path(home) / '.britive' / 'locks'
53+
lock_dir.mkdir(parents=True, exist_ok=True)
54+
55+
lock_name = hashlib.sha256(f'{mode}:{profile_key}'.lower().encode('utf-8')).hexdigest()[:16]
56+
self.lock_path: str = str(lock_dir / f'{lock_name}.lock')
57+
58+
def acquire(self) -> None:
59+
self._fd = os.open(self.lock_path, os.O_CREAT | os.O_RDWR)
60+
deadline = time.monotonic() + self.timeout
61+
while True:
62+
try:
63+
_lock_fd(self._fd)
64+
return
65+
except _WouldBlock:
66+
if time.monotonic() >= deadline:
67+
os.close(self._fd)
68+
self._fd = None
69+
raise CheckoutLockTimeout(
70+
f'Timed out after {self.timeout}s waiting for checkout lock'
71+
)
72+
time.sleep(self.poll_interval)
73+
74+
def release(self) -> None:
75+
if self._fd is not None:
76+
try:
77+
_unlock_fd(self._fd)
78+
finally:
79+
os.close(self._fd)
80+
self._fd = None
81+
82+
def __enter__(self) -> 'CheckoutLock':
83+
self.acquire()
84+
return self
85+
86+
def __exit__(
87+
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
88+
) -> bool:
89+
self.release()
90+
return False

src/pybritive/helpers/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def coalesce(*arg):
4343
'auto_refresh_kube_config',
4444
'auto_refresh_profile_cache',
4545
'ca_bundle',
46+
'checkout_lock',
4647
'credential_backend',
4748
'default_tenant',
4849
'output_format',
@@ -281,6 +282,9 @@ def validate_global(self, section, fields):
281282
if field.replace('-', '_') == 'auto_refresh_kube_config' and value not in ['true', 'false']:
282283
error = f'Invalid {section} field {field} value {value} provided. Invalid value choice.'
283284
self.validation_error_messages.append(error)
285+
if field.replace('-', '_') == 'checkout_lock' and value not in ['true', 'false']:
286+
error = f'Invalid {section} field {field} value {value} provided. Invalid value choice.'
287+
self.validation_error_messages.append(error)
284288
if field == 'default_tenant':
285289
tenant_aliases_from_sections = [extract_tenant(t) for t in self.config if t.startswith('tenant-')]
286290
if value not in tenant_aliases_from_sections:
@@ -343,3 +347,10 @@ def auto_refresh_kube_config(self):
343347
'auto_refresh_kube_config', self.config.get('global', {}).get('auto-refresh-kube-config', 'false')
344348
)
345349
return value == 'true'
350+
351+
def checkout_lock_enabled(self) -> bool:
352+
self.load()
353+
value = self.config.get('global', {}).get(
354+
'checkout_lock', self.config.get('global', {}).get('checkout-lock', 'false')
355+
)
356+
return value == 'true'

0 commit comments

Comments
 (0)