diff --git a/codegen/README.md b/codegen/README.md new file mode 100644 index 00000000..6c645ca3 --- /dev/null +++ b/codegen/README.md @@ -0,0 +1 @@ +codegen using `ast` to generate `sync_api` from `async_api` diff --git a/scratchattach/asyncio/__init__.py b/codegen/__init__.py similarity index 100% rename from scratchattach/asyncio/__init__.py rename to codegen/__init__.py diff --git a/codegen/main.py b/codegen/main.py new file mode 100644 index 00000000..9cc8cac3 --- /dev/null +++ b/codegen/main.py @@ -0,0 +1,231 @@ +# i am really unsure on how this should be implemented +from __future__ import annotations +import contextlib +import os +from copy import deepcopy +from typing import Any, TypedDict, cast, Optional, TYPE_CHECKING +import ast +from pathlib import Path +import json +import subprocess + +if TYPE_CHECKING: + from _typeshed import StrPath + + +PRE_CODEGEN_NAME = "IS_PRE_CODEGEN" +STATICALLY_ASYNC_NAME = "IS_ASYNC" +DYNAMICALLY_ASYNC_NAME = "IS_ASYNC" + + +class CodegenConfig(TypedDict): + sync_target_directory: str + async_target_directory: str + exclude: list[str] + include_directories: list[str] + + +class AsyncCodegenNodeTransformer(ast.NodeTransformer): + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + self.generic_visit(node) + + if node.name.endswith("_prim_sync"): + return None + + return node + + def _is_statically_async_literal(self, node: ast.AST) -> bool: + return isinstance(node, ast.Constant) and node.value == STATICALLY_ASYNC_NAME + + def _is_pre_codegen_literal(self, node: ast.AST) -> bool: + return isinstance(node, ast.Constant) and node.value == PRE_CODEGEN_NAME + + def _match_static_condition(self, test: ast.AST) -> bool | None: + if self._is_statically_async_literal(test): + return True + + if self._is_pre_codegen_literal(test): + return False + + if ( + isinstance(test, ast.UnaryOp) + and isinstance(test.op, ast.Not) + and self._is_statically_async_literal(test.operand) + ): + return False + + if ( + isinstance(test, ast.UnaryOp) + and isinstance(test.op, ast.Not) + and self._is_pre_codegen_literal(test.operand) + ): + return True + + return None + + def visit_If(self, node: ast.If) -> Any: + self.generic_visit(node) + + if (condition_value := self._match_static_condition(node.test)) is not None: + return node.body if condition_value else node.orelse + + return node + + +class SyncCodegenNodeTransformer(ast.NodeTransformer): + def visit_Assign(self, node: ast.Assign) -> Any: + self.generic_visit(node) + + if node.targets: + first_target = node.targets[0] + if isinstance(first_target, ast.Name) and first_target.id == DYNAMICALLY_ASYNC_NAME: + node.value = ast.Constant(value=False, kind=None) + + return node + + def visit_Await(self, node: ast.Await) -> Any: + self.generic_visit(node) + return node.value + + def visit_AsyncFor(self, node: ast.AsyncFor) -> Any: + self.generic_visit(node) + new_node = ast.For(**{field: getattr(node, field) for field in node._fields}) + return ast.copy_location(new_node, node) + + def visit_AsyncWith(self, node: ast.AsyncWith) -> Any: + self.generic_visit(node) + new_node = ast.With(**{field: getattr(node, field) for field in node._fields}) + return ast.copy_location(new_node, node) + + def visit_comprehension(self, node: ast.comprehension) -> Any: + self.generic_visit(node) + node.is_async = 0 + return node + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: + self.generic_visit(node) + + # primitive functions are implemented as sync and as async so the async variant can be dropped + if node.name.endswith("_prim"): + return None + + new_node = ast.FunctionDef(**{field: getattr(node, field) for field in node._fields}) + + return ast.copy_location(new_node, node) + + def _is_statically_async_literal(self, node: ast.AST) -> bool: + return isinstance(node, ast.Constant) and node.value == STATICALLY_ASYNC_NAME + + def _is_pre_codegen_literal(self, node: ast.AST) -> bool: + return isinstance(node, ast.Constant) and node.value == PRE_CODEGEN_NAME + + def _match_static_condition(self, test: ast.AST) -> bool | None: + if self._is_statically_async_literal(test): + return False + + if self._is_pre_codegen_literal(test): + return False + + if ( + isinstance(test, ast.UnaryOp) + and isinstance(test.op, ast.Not) + and self._is_statically_async_literal(test.operand) + ): + return True + + if ( + isinstance(test, ast.UnaryOp) + and isinstance(test.op, ast.Not) + and self._is_pre_codegen_literal(test.operand) + ): + return True + + return None + + def visit_If(self, node: ast.If) -> Any: + self.generic_visit(node) + + if (condition_value := self._match_static_condition(node.test)) is not None: + return node.body if condition_value else node.orelse + + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + self.generic_visit(node) + + if node.name.endswith("_prim_sync"): + node.name = node.name.removesuffix("_sync") + + return node + + +def codegen_for_ast(ast: ast.AST) -> tuple[ast.AST, ast.AST]: + ast_2 = deepcopy(ast) + return ( + SyncCodegenNodeTransformer().generic_visit(ast), + AsyncCodegenNodeTransformer().generic_visit(ast_2), + ) + + +def codegen_for_file(file: Path) -> tuple[ast.AST, ast.AST]: + code = file.read_text() + return codegen_for_ast(ast.parse(code)) + + +def codegen_for_whole_directory(directory: "StrPath"): + directory = Path(directory).resolve() + items = {path.name: path for path in directory.iterdir()} + codegen_config: CodegenConfig + try: + codegen_config = cast( + "CodegenConfig", json.loads(items.pop("codegen_config.json").read_text()) + ) + except KeyError: + codegen_config = CodegenConfig( + sync_target_directory=str( + directory.with_stem(f"{directory.stem}_sync"), + ), + async_target_directory=str( + directory.with_stem(f"{directory.stem}_async"), + ), + exclude=[], + include_directories=[], + ) + sync_target_directory = directory / codegen_config["sync_target_directory"] + async_target_directory = directory / codegen_config["async_target_directory"] + sync_target_directory.mkdir(parents=True, exist_ok=True) + async_target_directory.mkdir(parents=True, exist_ok=True) + exclusions = {(directory / exclusion).resolve() for exclusion in codegen_config["exclude"]} + for path in items.values(): + path = path.resolve() + if path.suffix.lower() != ".py": + continue + if not path.is_file(): + continue + if path in exclusions: + continue + (sync_ast, async_ast) = codegen_for_file(path) + (sync_code, async_code) = (ast.unparse(sync_ast), ast.unparse(async_ast)) + (sync_target_directory / path.name).write_text(sync_code) + (async_target_directory / path.name).write_text(async_code) + subprocess.run( + [ + "python", + "-m", + "ruff", + "format", + str(sync_target_directory.resolve()), + str(async_target_directory.resolve()), + ], + capture_output=True, + text=True, + ) + for included_dir in codegen_config["include_directories"]: + codegen_for_whole_directory(directory / included_dir) + + +def main(): ... + + +if __name__ == "__main__": + main() diff --git a/codegen/pyproject.toml b/codegen/pyproject.toml new file mode 100644 index 00000000..f4c468ff --- /dev/null +++ b/codegen/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "codegen" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12.12" +dependencies = ["ruff"] diff --git a/codegen/test_codegen/async_out/test.py b/codegen/test_codegen/async_out/test.py new file mode 100644 index 00000000..48aceb89 --- /dev/null +++ b/codegen/test_codegen/async_out/test.py @@ -0,0 +1,58 @@ +from typing import Iterable, TypeVar, ParamSpec, Generic, Any, TYPE_CHECKING, Optional, cast +from collections.abc import Callable +import time + +IS_ASYNC = True +if IS_ASYNC: + from collections.abc import Awaitable + import asyncio +else: + import threading + from dataclasses import dataclass +P = ParamSpec("P") +O = TypeVar("O") + + +def create_task(function: Callable[P, O], *args: P.args, **kwargs: P.kwargs) -> O: + return function(*args, **kwargs) + + +async def sleep_prim(delay: int | float): + await asyncio.sleep(delay) + + +T = TypeVar("T") + + +async def gather_prim(*tasks: Awaitable[T]) -> list[T]: + return await asyncio.gather(*tasks) + + +async def fetch_user_data(user_id: int, delay: int) -> dict: + print(f"[{time.strftime('%X')}] Task {user_id}: Starting request (takes {delay}s)...") + await sleep_prim(delay) + print(f"[{time.strftime('%X')}] Task {user_id}: Finished request!") + return {"user_id": user_id, "status": "success"} + + +async def main(): + start_time = time.perf_counter() + print("--- Fetching data concurrently ---") + coroutines = [ + create_task(fetch_user_data, user_id=1, delay=2), + create_task(fetch_user_data, user_id=2, delay=3), + create_task(fetch_user_data, user_id=3, delay=1), + ] + results = await gather_prim(*coroutines) + end_time = time.perf_counter() + total_time = end_time - start_time + print("\n--- All tasks complete ---") + print(f"Total time taken: {total_time:.2f} seconds") + print("Results:", results) + + +if __name__ == "__main__": + if IS_ASYNC: + asyncio.run(main()) + else: + main() diff --git a/codegen/test_codegen/codegen_config.json b/codegen/test_codegen/codegen_config.json new file mode 100644 index 00000000..6e8371cc --- /dev/null +++ b/codegen/test_codegen/codegen_config.json @@ -0,0 +1,6 @@ +{ + "sync_target_directory": "./sync_out", + "async_target_directory": "./async_out", + "exclude": [], + "include_directories": [] +} \ No newline at end of file diff --git a/codegen/test_codegen/sync_out/test.py b/codegen/test_codegen/sync_out/test.py new file mode 100644 index 00000000..72e951c1 --- /dev/null +++ b/codegen/test_codegen/sync_out/test.py @@ -0,0 +1,76 @@ +from typing import Iterable, TypeVar, ParamSpec, Generic, Any, TYPE_CHECKING, Optional, cast +from collections.abc import Callable +import time + +IS_ASYNC = False +if IS_ASYNC: + from collections.abc import Awaitable + import asyncio +else: + import threading + from dataclasses import dataclass +P = ParamSpec("P") +O = TypeVar("O") + + +@dataclass +class Task(Generic[O]): + out: Optional[O] + thread: threading.Thread + + +def create_task(function: Callable[P, O], *args: P.args, **kwargs: P.kwargs) -> Task[O]: + task: Task[O] = Task(None, cast(threading.Thread, None)) + + def wrapper(*args, **kwargs): + task.out = function(*args, **kwargs) + + task.thread = threading.Thread(target=wrapper, args=args, kwargs=kwargs) + return task + + +def sleep_prim(delay: int | float): + time.sleep(delay) + + +T = TypeVar("T") + + +def gather_prim(*tasks: Task[T]) -> list[T]: + values: list[T] = [] + for task in tasks: + task.thread.start() + for task in tasks: + task.thread.join() + values.append(cast(T, task.out)) + return values + + +def fetch_user_data(user_id: int, delay: int) -> dict: + print(f"[{time.strftime('%X')}] Task {user_id}: Starting request (takes {delay}s)...") + sleep_prim(delay) + print(f"[{time.strftime('%X')}] Task {user_id}: Finished request!") + return {"user_id": user_id, "status": "success"} + + +def main(): + start_time = time.perf_counter() + print("--- Fetching data concurrently ---") + coroutines = [ + create_task(fetch_user_data, user_id=1, delay=2), + create_task(fetch_user_data, user_id=2, delay=3), + create_task(fetch_user_data, user_id=3, delay=1), + ] + results = gather_prim(*coroutines) + end_time = time.perf_counter() + total_time = end_time - start_time + print("\n--- All tasks complete ---") + print(f"Total time taken: {total_time:.2f} seconds") + print("Results:", results) + + +if __name__ == "__main__": + if IS_ASYNC: + asyncio.run(main()) + else: + main() diff --git a/codegen/test_codegen/test.py b/codegen/test_codegen/test.py new file mode 100644 index 00000000..81960370 --- /dev/null +++ b/codegen/test_codegen/test.py @@ -0,0 +1,110 @@ +from typing import Iterable, TypeVar, ParamSpec, Generic, Any, TYPE_CHECKING, Optional, cast +from collections.abc import Callable + +import time + +IS_ASYNC = True +if IS_ASYNC: + from collections.abc import Awaitable + import asyncio +else: + import threading + from dataclasses import dataclass + + +P = ParamSpec("P") +O = TypeVar("O") + +if "IS_PRE_CODEGEN": + if TYPE_CHECKING: + import threading + from dataclasses import dataclass + + @dataclass + class Task(Generic[O]): + out: Optional[O] + thread: threading.Thread + + +if "IS_ASYNC": + + def create_task(function: Callable[P, O], *args: P.args, **kwargs: P.kwargs) -> O: + return function(*args, **kwargs) +else: + + @dataclass + class Task(Generic[O]): # type: ignore[no-redef] + out: Optional[O] + thread: threading.Thread + + def create_task(function: Callable[P, O], *args: P.args, **kwargs: P.kwargs) -> Task[O]: # type: ignore[misc] + task: Task[O] = Task(None, cast(threading.Thread, None)) # type: ignore[arg-type] + + def wrapper(*args, **kwargs): + task.out = function(*args, **kwargs) + + task.thread = threading.Thread(target=wrapper, args=args, kwargs=kwargs) + return task + + +def sleep_prim_sync(delay: int | float): + time.sleep(delay) + + +async def sleep_prim(delay: int | float): + await asyncio.sleep(delay) + + +T = TypeVar("T") + + +def gather_prim_sync(*tasks: Task[T]) -> list[T]: + values: list[T] = [] + for task in tasks: + task.thread.start() + + for task in tasks: + task.thread.join() + values.append(cast(T, task.out)) + + return values + + +async def gather_prim(*tasks: Awaitable[T]) -> list[T]: + return await asyncio.gather(*tasks) + + +async def fetch_user_data(user_id: int, delay: int) -> dict: + print(f"[{time.strftime('%X')}] Task {user_id}: Starting request (takes {delay}s)...") + + await sleep_prim(delay) + + print(f"[{time.strftime('%X')}] Task {user_id}: Finished request!") + return {"user_id": user_id, "status": "success"} + + +async def main(): + start_time = time.perf_counter() + print("--- Fetching data concurrently ---") + + coroutines = [ + create_task(fetch_user_data, user_id=1, delay=2), + create_task(fetch_user_data, user_id=2, delay=3), + create_task(fetch_user_data, user_id=3, delay=1), + ] + + results = await gather_prim(*coroutines) + + end_time = time.perf_counter() + total_time = end_time - start_time + + print("\n--- All tasks complete ---") + print(f"Total time taken: {total_time:.2f} seconds") + print("Results:", results) + + +if __name__ == "__main__": + if IS_ASYNC: + asyncio.run(main()) + else: + main() diff --git a/pyproject.toml b/pyproject.toml index b3a8d4fb..397a70ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "browser_cookie3", "aiohttp", "rich", + "httpx>=0.28.1", ] readme = "README.md" license = "MIT" @@ -97,6 +98,11 @@ max-complexity = 10 [tool.uv] config-settings = { editable_mode = "compat" } +[tool.uv.workspace] +members = [ + "codegen", +] + [tool.setuptools.packages.find] where = ["."] include = ["scratchattach*"] diff --git a/scratchattach/_core/__init__.py b/scratchattach/_core/__init__.py new file mode 100644 index 00000000..f0dc916a --- /dev/null +++ b/scratchattach/_core/__init__.py @@ -0,0 +1,4 @@ +""" +Async implementations for blocking operations that are not to be used themselves. +The implementations are read by the global codegen module and async and sync variations are generated. +""" \ No newline at end of file diff --git a/scratchattach/_core/codegen_config.json b/scratchattach/_core/codegen_config.json new file mode 100644 index 00000000..57066d3d --- /dev/null +++ b/scratchattach/_core/codegen_config.json @@ -0,0 +1,8 @@ +{ + "sync_target_directory": "../sync_api", + "async_target_directory": "../async_api", + "exclude": [ + "./__init__.py" + ], + "include_directories": ["./primitives"] +} \ No newline at end of file diff --git a/scratchattach/_core/primitives/__init__.py b/scratchattach/_core/primitives/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scratchattach/_core/primitives/codegen_config.json b/scratchattach/_core/primitives/codegen_config.json new file mode 100644 index 00000000..673f7408 --- /dev/null +++ b/scratchattach/_core/primitives/codegen_config.json @@ -0,0 +1,8 @@ +{ + "sync_target_directory": "../../sync_api/primitives", + "async_target_directory": "../../async_api/primitives", + "exclude": [ + "./__init__.py" + ], + "include_directories": [] +} \ No newline at end of file diff --git a/scratchattach/_core/primitives/utils.py b/scratchattach/_core/primitives/utils.py new file mode 100644 index 00000000..9ab56cff --- /dev/null +++ b/scratchattach/_core/primitives/utils.py @@ -0,0 +1,316 @@ +from __future__ import annotations +import _asyncio +from collections.abc import Callable +from typing import Union, ParamSpec, TypeVar, Generic, Any, cast, Optional, overload, Literal +import time + +if "IS_PRE_CODEGEN": + CTYPES_PRESENT = True + import ctypes + import threading + import concurrent.futures + import asyncio + from collections.abc import Awaitable, Coroutine +else: + if "IS_ASYNC": + CTYPES_PRESENT = False + import asyncio + from collections.abc import Awaitable, Coroutine + else: + try: + import ctypes + + CTYPES_PRESENT = True + except Exception: + CTYPES_PRESENT = False + import threading + import concurrent.futures + + +def sleep_prim_sync(delay: Union[int, float]): + time.sleep(delay) + + +async def sleep_prim(delay: Union[int, float]): + await asyncio.sleep(delay) + + +P = ParamSpec("P") +O = TypeVar("O", covariant=True) + + +class Task(Generic[P, O]): + function: Callable[P, O] + args: Any + kwargs: Any + available: bool + + +class LaunchedTask(Generic[P, O]): + task: Task[P, O] + if "IS_PRE_CODEGEN": + _out: O + _task: asyncio.Task[Any] + _thread: threading.Thread + else: + if "IS_ASYNC": + _task: asyncio.Task[Any] # type: ignore[no-redef] + else: + _out: O # type: ignore[no-redef] + _thread: threading.Thread # type: ignore[no-redef] + + +def create_task(function: Callable[P, O], *args: P.args, **kwargs: P.kwargs) -> Task[P, O]: + task: Task[P, O] = Task() + task.function = function + task.args = args + task.kwargs = kwargs + task.available = True + return task + + +def gather_concurrently_prim_sync(*tasks: Task[Any, O]) -> list[O]: + with concurrent.futures.ThreadPoolExecutor() as executor: + return [cast(O, i) for i in executor.map(lambda x: x.function(*x.args, **x.kwargs), tasks)] + + +async def gather_concurrently_prim(*tasks: Task[Any, Awaitable[O]]) -> list[O]: + for task in tasks: + if not task.available: + raise ValueError("Task is already used.") + task.available = False + return await asyncio.gather(*(task.function(*task.args, **task.kwargs) for task in tasks)) + + +def launch_concurrently_prim_sync(task: Task[P, O]) -> LaunchedTask[P, O]: + launched_task: LaunchedTask[P, O] = LaunchedTask() + + def wrap_function(): + launched_task._out = task.function(*task.args, **task.kwargs) + + thread = threading.Thread(target=wrap_function) + thread.start() + launched_task.task = task + launched_task._thread = thread + return launched_task + + +A = TypeVar("A") +B = TypeVar("B") + + +async def launch_concurrently_prim( + task: Task[P, Coroutine[A, B, O]], +) -> LaunchedTask[P, Coroutine[A, B, O]]: + _task = asyncio.create_task(task.function(*task.args, **task.kwargs)) + launched_task: LaunchedTask[P, Coroutine[A, B, O]] = LaunchedTask() + launched_task.task = task + launched_task._task = _task + return launched_task + + +@overload +def join_launched_task_prim_sync(task: LaunchedTask[P, O]) -> O: + pass + + +@overload +def join_launched_task_prim_sync( + task: LaunchedTask[P, O], timeout: Union[float, int] +) -> Optional[O]: + pass + + +def join_launched_task_prim_sync( + task: LaunchedTask[P, O], timeout: Optional[Union[float, int]] = None +) -> Optional[O]: + task._thread.join(timeout) + if task._thread.is_alive(): + return None + return task._out + + +@overload +async def join_launched_task_prim(task: LaunchedTask[P, Coroutine[Any, Any, O]]) -> O: + pass + + +@overload +async def join_launched_task_prim( + task: LaunchedTask[P, Coroutine[Any, Any, O]], timeout: Union[float, int] +) -> Optional[O]: + pass + + +async def join_launched_task_prim( + task: LaunchedTask[P, Coroutine[Any, Any, O]], timeout: Optional[Union[float, int]] = None +) -> Optional[O]: + try: + return await asyncio.wait_for(asyncio.shield(task._task), timeout) + except TimeoutError: + return None + + +if "IS_PRE_CODEGEN": + + def _raise_in_thread(thread: threading.Thread, exc_type: type[BaseException]) -> None: ... + + +if not "IS_ASYNC": + + def _raise_in_thread(thread: threading.Thread, exc_type: type[BaseException]) -> None: + if not CTYPES_PRESENT: + raise NotImplementedError( + "Sending exceptions to threads is not supported in this Python version." + ) + + if not thread.is_alive(): + raise ValueError("Thread is not alive.") + + thread_id = thread.ident + if thread_id is None: + raise ValueError("Thread has no ident.") + + result = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_ulong(thread_id), + ctypes.py_object(exc_type), + ) + + if result == 0: + raise ValueError("Thread ident is invalid.") + + if result > 1: + ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_ulong(thread_id), + None, + ) + raise SystemError("PyThreadState_SetAsyncExc failed.") + + +@overload +def kill_launched_task_prim_sync( + task: LaunchedTask[P, O], *, exception_interval: Union[float, int] = 0.1 +) -> Literal[True]: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +@overload +def kill_launched_task_prim_sync( + task: LaunchedTask[P, O], + timeout: Union[float, int], + *, + exception_interval: Union[float, int] = 0.1, +) -> bool: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +def kill_launched_task_prim_sync( + task: LaunchedTask[P, O], + timeout: Optional[Union[float, int]] = None, + *, + exception_interval: Union[float, int] = 0.1, +) -> bool: + has_timeout, timeout_end = ( + (True, time.time() + timeout) if timeout is not None else (False, None) + ) + while (not has_timeout) or (timeout_end is not None and time.time() <= timeout_end): + if not task._thread.is_alive(): + break + _raise_in_thread(task._thread, SystemExit) + time.sleep(exception_interval) + if has_timeout and timeout_end is not None and time.time() > timeout_end: + return False + return True + + +@overload +async def kill_launched_task_prim( + task: LaunchedTask[P, O], *, exception_interval: Union[float, int] = 0.1 +) -> Literal[True]: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +@overload +async def kill_launched_task_prim( + task: LaunchedTask[P, O], + timeout: Union[float, int], + *, + exception_interval: Union[float, int] = 0.1, +) -> bool: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +async def kill_launched_task_prim( + task: LaunchedTask[P, O], + timeout: Optional[Union[float, int]] = None, + *, + exception_interval: Union[float, int] = 0.1, +) -> bool: + has_timeout, timeout_end = ( + (True, time.time() + timeout) if timeout is not None else (False, None) + ) + if task._task.cancel(): + return True + while (not has_timeout) or (timeout_end is not None and time.time() <= timeout_end): + if not task._task.done(): + break + task._task.set_exception(SystemExit) + await asyncio.sleep(exception_interval) + if has_timeout and timeout_end is not None and time.time() > timeout_end: + return False + return True + + +# async def task_1(): +# print("Starting task 1...") +# await sleep_prim(2) +# print("Task 1 done.") + + +# async def task_2(msg: Any): +# print("Starting task 2...") +# await sleep_prim(1) +# print("Task 2 says:", msg) +# print("Task 2 done.") + + +# async def task_3(delay: Union[float, int]): +# print("Starting task 3...") +# await sleep_prim(delay) +# print("Task 3 done.") + + +# async def main(): +# await gather_concurrently_prim( +# create_task(task_1), create_task(task_2, msg="Hello there!"), create_task(task_3, 3) +# ) +# print("Launching task...") +# task = await launch_concurrently_prim(create_task(task_3, 5)) +# print("Launched task.") +# await sleep_prim(4) +# print("Joining task...") +# await join_launched_task_prim(task) +# print("Task done.") + + +# if __name__ == "__main__": +# if "IS_ASYNC": +# asyncio.run(main()) +# else: +# main() diff --git a/scratchattach/async_api/__init__.py b/scratchattach/async_api/__init__.py new file mode 100644 index 00000000..a4e88681 --- /dev/null +++ b/scratchattach/async_api/__init__.py @@ -0,0 +1,17 @@ +import httpx + +client = httpx.AsyncClient() + + +async def get_home() -> str: + resp = await client.get("https://scratch.mit.edu") + return resp.text + + +if __name__ == "__main__": + import asyncio + + async def main(): + print(await get_home()) + + asyncio.run(main()) diff --git a/scratchattach/async_api/primitives/utils.py b/scratchattach/async_api/primitives/utils.py new file mode 100644 index 00000000..7b93addc --- /dev/null +++ b/scratchattach/async_api/primitives/utils.py @@ -0,0 +1,113 @@ +from __future__ import annotations +import _asyncio +from collections.abc import Callable +from typing import Union, ParamSpec, TypeVar, Generic, Any, cast, Optional, overload, Literal +import time + +CTYPES_PRESENT = False +import asyncio +from collections.abc import Awaitable, Coroutine + + +async def sleep_prim(delay: Union[int, float]): + await asyncio.sleep(delay) + + +P = ParamSpec("P") +O = TypeVar("O", covariant=True) + + +class Task(Generic[P, O]): + function: Callable[P, O] + args: Any + kwargs: Any + available: bool + + +class LaunchedTask(Generic[P, O]): + task: Task[P, O] + _task: asyncio.Task[Any] + + +def create_task(function: Callable[P, O], *args: P.args, **kwargs: P.kwargs) -> Task[P, O]: + task: Task[P, O] = Task() + task.function = function + task.args = args + task.kwargs = kwargs + task.available = True + return task + + +async def gather_concurrently_prim(*tasks: Task[Any, Awaitable[O]]) -> list[O]: + for task in tasks: + if not task.available: + raise ValueError("Task is already used.") + task.available = False + return await asyncio.gather(*(task.function(*task.args, **task.kwargs) for task in tasks)) + + +A = TypeVar("A") +B = TypeVar("B") + + +async def launch_concurrently_prim(task: Task[P, Coroutine[A, B, O]]) -> LaunchedTask[P, Coroutine[A, B, O]]: + _task = asyncio.create_task(task.function(*task.args, **task.kwargs)) + launched_task: LaunchedTask[P, Coroutine[A, B, O]] = LaunchedTask() + launched_task.task = task + launched_task._task = _task + return launched_task + + +@overload +async def join_launched_task_prim(task: LaunchedTask[P, Coroutine[Any, Any, O]]) -> O: + pass + + +@overload +async def join_launched_task_prim(task: LaunchedTask[P, Coroutine[Any, Any, O]], timeout: Union[float, int]) -> Optional[O]: + pass + + +async def join_launched_task_prim( + task: LaunchedTask[P, Coroutine[Any, Any, O]], timeout: Optional[Union[float, int]] = None +) -> Optional[O]: + try: + return await asyncio.wait_for(asyncio.shield(task._task), timeout) + except TimeoutError: + return None + + +@overload +async def kill_launched_task_prim(task: LaunchedTask[P, O], *, exception_interval: Union[float, int] = 0.1) -> Literal[True]: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +@overload +async def kill_launched_task_prim( + task: LaunchedTask[P, O], timeout: Union[float, int], *, exception_interval: Union[float, int] = 0.1 +) -> bool: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +async def kill_launched_task_prim( + task: LaunchedTask[P, O], timeout: Optional[Union[float, int]] = None, *, exception_interval: Union[float, int] = 0.1 +) -> bool: + has_timeout, timeout_end = (True, time.time() + timeout) if timeout is not None else (False, None) + if task._task.cancel(): + return True + while not has_timeout or (timeout_end is not None and time.time() <= timeout_end): + if not task._task.done(): + break + task._task.set_exception(SystemExit) + await asyncio.sleep(exception_interval) + if has_timeout and timeout_end is not None and (time.time() > timeout_end): + return False + return True diff --git a/scratchattach/site/session.py b/scratchattach/site/session.py index bf182994..2a4cdf13 100644 --- a/scratchattach/site/session.py +++ b/scratchattach/site/session.py @@ -1495,5 +1495,6 @@ def login_from_browser(browser: Browser = ANY): """ cookies = cookies_from_browser(browser) if "scratchsessionsid" in cookies: - return login_by_id(cookies["scratchsessionsid"]) + with suppress_login_warning(): + return login_by_id(cookies["scratchsessionsid"]) raise ValueError("Not enough data to log in.") diff --git a/scratchattach/sync_api/primitives/utils.py b/scratchattach/sync_api/primitives/utils.py new file mode 100644 index 00000000..2bbb6198 --- /dev/null +++ b/scratchattach/sync_api/primitives/utils.py @@ -0,0 +1,133 @@ +from __future__ import annotations +import _asyncio +from collections.abc import Callable +from typing import Union, ParamSpec, TypeVar, Generic, Any, cast, Optional, overload, Literal +import time + +try: + import ctypes + + CTYPES_PRESENT = True +except Exception: + CTYPES_PRESENT = False +import threading +import concurrent.futures + + +def sleep_prim(delay: Union[int, float]): + time.sleep(delay) + + +P = ParamSpec("P") +O = TypeVar("O", covariant=True) + + +class Task(Generic[P, O]): + function: Callable[P, O] + args: Any + kwargs: Any + available: bool + + +class LaunchedTask(Generic[P, O]): + task: Task[P, O] + _out: O + _thread: threading.Thread + + +def create_task(function: Callable[P, O], *args: P.args, **kwargs: P.kwargs) -> Task[P, O]: + task: Task[P, O] = Task() + task.function = function + task.args = args + task.kwargs = kwargs + task.available = True + return task + + +def gather_concurrently_prim(*tasks: Task[Any, O]) -> list[O]: + with concurrent.futures.ThreadPoolExecutor() as executor: + return [cast(O, i) for i in executor.map(lambda x: x.function(*x.args, **x.kwargs), tasks)] + + +def launch_concurrently_prim(task: Task[P, O]) -> LaunchedTask[P, O]: + launched_task: LaunchedTask[P, O] = LaunchedTask() + + def wrap_function(): + launched_task._out = task.function(*task.args, **task.kwargs) + + thread = threading.Thread(target=wrap_function) + thread.start() + launched_task.task = task + launched_task._thread = thread + return launched_task + + +A = TypeVar("A") +B = TypeVar("B") + + +@overload +def join_launched_task_prim(task: LaunchedTask[P, O]) -> O: + pass + + +@overload +def join_launched_task_prim(task: LaunchedTask[P, O], timeout: Union[float, int]) -> Optional[O]: + pass + + +def join_launched_task_prim(task: LaunchedTask[P, O], timeout: Optional[Union[float, int]] = None) -> Optional[O]: + task._thread.join(timeout) + if task._thread.is_alive(): + return None + return task._out + + +def _raise_in_thread(thread: threading.Thread, exc_type: type[BaseException]) -> None: + if not CTYPES_PRESENT: + raise NotImplementedError("Sending exceptions to threads is not supported in this Python version.") + if not thread.is_alive(): + raise ValueError("Thread is not alive.") + thread_id = thread.ident + if thread_id is None: + raise ValueError("Thread has no ident.") + result = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(thread_id), ctypes.py_object(exc_type)) + if result == 0: + raise ValueError("Thread ident is invalid.") + if result > 1: + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(thread_id), None) + raise SystemError("PyThreadState_SetAsyncExc failed.") + + +@overload +def kill_launched_task_prim(task: LaunchedTask[P, O], *, exception_interval: Union[float, int] = 0.1) -> Literal[True]: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +@overload +def kill_launched_task_prim( + task: LaunchedTask[P, O], timeout: Union[float, int], *, exception_interval: Union[float, int] = 0.1 +) -> bool: + """ + Sends exceptions to the underlying concurrency primitive. + May also try to use the recommended way of cancelling the primitive if there is one. + Returns whether the task was actually killed. + """ + + +def kill_launched_task_prim( + task: LaunchedTask[P, O], timeout: Optional[Union[float, int]] = None, *, exception_interval: Union[float, int] = 0.1 +) -> bool: + has_timeout, timeout_end = (True, time.time() + timeout) if timeout is not None else (False, None) + while not has_timeout or (timeout_end is not None and time.time() <= timeout_end): + if not task._thread.is_alive(): + break + _raise_in_thread(task._thread, SystemExit) + time.sleep(exception_interval) + if has_timeout and timeout_end is not None and (time.time() > timeout_end): + return False + return True diff --git a/tests/test_memberships.py b/tests/test_memberships.py index 72dc5fc9..80aea3cc 100644 --- a/tests/test_memberships.py +++ b/tests/test_memberships.py @@ -13,8 +13,8 @@ def test_memberships(): u2 = sa.get_user("ceebee") assert u2.is_member - assert u2.has_ears - assert u2.has_badge() + assert not u2.has_ears + assert not u2.has_badge() u3 = sa.get_user("scratchattachv2") assert not u3.is_member diff --git a/uv.lock b/uv.lock index 6dae1154..32ce3963 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,12 @@ version = 1 revision = 3 -requires-python = ">=3.12" +requires-python = ">=3.12.12" + +[manifest] +members = [ + "codegen", + "scratchattach", +] [[package]] name = "aiohappyeyeballs" @@ -109,6 +115,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "anyio" +version = "4.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/14/2c5dd9f512b66549ae92767a9c7b330ae88e1932ca57876909410251fe13/anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc", size = 231622, upload-time = "2026-03-24T12:59:09.671Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -224,6 +243,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, ] +[[package]] +name = "codegen" +version = "0.1.0" +source = { virtual = "codegen" } + [[package]] name = "frozenlist" version = "1.8.0" @@ -313,6 +337,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/9a/e35b4a917281c0b8419d4207f4334c8e8c5dbf4f3f5f9ada73958d937dcc/frozenlist-1.8.0-py3-none-any.whl", hash = "sha256:0c18a16eab41e82c295618a77502e17b195883241c563b00f0aa5106fc4eaa0d", size = 13409, upload-time = "2025-10-06T05:38:16.721Z" }, ] +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + [[package]] name = "idna" version = "3.11" @@ -782,6 +843,7 @@ dependencies = [ { name = "aiohttp" }, { name = "browser-cookie3" }, { name = "bs4" }, + { name = "httpx" }, { name = "requests" }, { name = "rich" }, { name = "simplewebsocketserver" }, @@ -807,6 +869,7 @@ requires-dist = [ { name = "aiohttp" }, { name = "browser-cookie3" }, { name = "bs4" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "lark", marker = "extra == 'lark'" }, { name = "requests" }, { name = "rich" },