Compare commits

...

2 Commits

4 changed files with 471 additions and 0 deletions

View File

@ -1,3 +1,4 @@
import select
import warnings
import platform
from typing import Any
@ -655,5 +656,20 @@ class AsyncSocket(AsyncResource):
except WantWrite:
await wait_writable(self._fd)
def is_readable(self):
"""
Returns whether the OS thinks this socket is
readable. For more info see https://github.com/python-trio/trio/issues/760
"""
match platform.system():
case "Windows":
readable, _, _ = select.select([self.socket], [], [], 0)
return bool(readable)
case _:
p = select.poll()
p.register(self.socket, select.POLLIN)
return bool(p.poll(0))
def __repr__(self):
return f"AsyncSocket({self.socket})"

445
tests/chrono.py Normal file
View File

@ -0,0 +1,445 @@
import time
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from threading import Lock, Event, Thread
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from timeit import default_timer
import structio
def parse_delay(delta: str) -> timedelta:
"""
Parse a delay string like '1y 2mo 3w 4d 5h 6m 7s'
into a timedelta object. Case and space insensitive
"""
second = 1
minute = second * 60
hour = minute * 60
day = hour * 24
week = day * 7
# 30 days is just the most commonly accepted
# duration for a month, even though 4 weeks
# also is, and they don't always match up in
# duration
month = day * 30
# No leap years or stuff like that. Sorry!
year = day * 365
units = {"s": second,
"m": minute,
"h": hour,
"d": day,
"w": week,
"mo": month,
"y": year}
# We need this to handle multi-character units
# like months
longest = max(len(k) for k in units.keys())
seconds = 0
default_unit = unit = units['d'] # Defaults to days
value = ""
# Clean up the string, we don't need this junk anyway
delta = delta.translate(str.maketrans({"\n": "", "\t": "", "\r": "", " ": ""})).lower()
i = 0
while i < len(delta):
ch = delta[i]
i += 1
# isdigit will return True even for things like numbers
# in Arabic and stuff. For my use case this is undesirable,
# but if you care about that, just swap isnumeric() for isdigit()
if ch.isnumeric():
value += ch
# Only restrict ourselves to ASCII letters. Again, this is specific
# to my use-case. Can edit as necessary
elif ch.isalpha() and ch.isascii():
unit_name = ch
j = i
# Try to parse as many characters of the unit as possible,
# but stop once we either reach the end of the string or
# the length of the current (potential) unit equals the
# length of the longest one we know
while j < len(delta) and len(unit_name) < longest:
if (char := delta[j]).isalpha() and char.isascii():
unit_name += char
else:
break
j += 1
# If the unit was more than one character long, we
# have to skip the next len(unit_name) - 1 characters,
# so we don't try to parse 'mo' and then 'o' again, for
# example. If the unit was one character long, this value
# will just be zero
i += len(unit_name) - 1
try:
unit = units[unit_name]
except KeyError:
raise ValueError(f"invalid unit {unit_name!r}")
if value:
seconds += unit * int(value)
value = ""
unit = default_unit
else:
# Unknown character
raise ValueError(f"Invalid character {ch!r} in time delta: {delta!r}")
if value:
seconds = unit * int(value)
return timedelta(seconds=seconds)
@dataclass(order=True)
class ScheduledTask:
"""
A scheduled task
"""
command: str = field(compare=True)
when: timedelta = field(compare=True)
last: datetime | None = field(compare=False, default=None)
def should_run(self):
if self.last is None:
return True
return (datetime.now() - self.last) >= self.when
class TaskScheduler(ABC):
"""
A task scheduler. The concurrency_limit argument
limits how many tasks can be run at the same time
"""
def __init__(self, concurrency_limit: int | None = None):
"""
Public object constructor
"""
self._running = False
self._stop = False
self._timer = default_timer
self.concurrency_limit = concurrency_limit
@property
def running(self) -> bool:
"""
Returns whether the event scheduler is running
"""
return self._running
@property
def time(self) -> float:
"""
Return the scheduler's internal wall clock
time, in seconds
"""
return self._timer()
@property
def stopping(self) -> bool:
"""
Returns whether the scheduler is in the process of
shutting down
"""
return self.running and self._stop
@property
def stopped(self) -> bool:
"""
Returns whether the scheduler has stopped
"""
return not self.running and self._stop
@abstractmethod
def schedule(self, command: str, when: timedelta) -> ScheduledTask:
"""
Schedules the given command to run with the chosen
interval. Duplicates are allowed. A ScheduledTask
object is returned
"""
return NotImplemented()
@abstractmethod
def unschedule(self, task: ScheduledTask):
"""
Unschedules the command associated to the given
ScheduledTask object. Note that this only removes
the command associate with this exact object, not
any of its duplicates (for that, use unschedule_all).
If the given task is not scheduled, this method does
nothing
"""
return NotImplemented()
@abstractmethod
def unschedule_all(self, task: ScheduledTask) -> int:
"""
Unschedules all commands associated to the
given ScheduledTask object. Unlike unschedule(),
this method checks for object equality rather than
identity, so it removes all instances of the given
command that match the given one (meaning it can be
instantiated 'ex-novo' without needing the specific
object returned by schedule()). Returns the number
of tasks that were unscheduled
"""
return NotImplemented()
@abstractmethod
def start(self):
"""
Start the event scheduler. Unless explicitly
stated elsewhere, tasks can be scheduled both
before and during the call to run(), and they
will be executed as expected. Tasks scheduled
after calling stop() will be executed at the
next call to run(). If the scheduler is already
running, RuntimeError is raised. Note that, depending
on the underlying implementation, this method call may
block until the event loop is stopped by calling stop()
(usually by another thread or task)
"""
return NotImplemented()
@abstractmethod
def stop(self):
"""
Stop the event scheduler. It can be restarted
by calling start() again. Note that the scheduler
will first terminate running active tasks before
stopping, but stop() does not block to reflect that
"""
return NotImplemented()
class SyncTaskScheduler(TaskScheduler):
"""
A synchronous, thread-safe
task scheduler. Calling start() blocks
the calling thread until the scheduler
is stopped by calling stop() in another
thread. Commands are executed in their
own process
"""
def __init__(self, concurrency_limit: int | None = None):
super().__init__(concurrency_limit)
self._sem = Lock()
self._evt = Event()
self._executor = ProcessPoolExecutor(cpu_count() * 2 - 1 if concurrency_limit is None else concurrency_limit)
self._tie = 0
self._tasks: list[tuple[ScheduledTask, int]] = []
@staticmethod
def _run_command(task: ScheduledTask):
print(f"[{datetime.now().strftime('%d/%m/%Y %H:%M:%S %p')}] Running command {task.command!r}")
# TODO: os.exec(...)
def schedule(self, command: str, when: timedelta) -> ScheduledTask:
task = ScheduledTask(command, when)
with self._sem:
self._tasks.append((task, self._tie))
self._tie += 1
self._tasks.sort()
# Wakeup other thread blocked in start()
self._evt.set()
self._evt.clear()
return task
def unschedule(self, task: ScheduledTask):
with self._sem:
for t in self._tasks:
if t[0] is task:
self._tasks.remove(t)
self._tasks.sort()
break
def unschedule_all(self, task: ScheduledTask) -> int:
result = 0
with self._sem:
for t in self._tasks:
if t[0] == task:
self._tasks.remove(t)
result += 1
if result > 0:
self._tasks.sort()
return result
def start(self):
with self._sem:
if self.running:
raise RuntimeError("scheduler is already running")
self._stop = False
while True:
with self._sem:
if self._stop:
self._running = False
break
self._sem.acquire()
if not self._tasks:
# No tasks to run. To avoid busy-waiting, we
# go to sleep indefinitely on an event until
# a task is scheduled. Note how we don't hold
# the lock before calling wait, this is to avoid
# a deadlock when other threads call schedule()
self._sem.release()
self._evt.wait()
if self._sem.locked():
self._sem.release()
tasks: list[ScheduledTask] = []
with self._sem:
for task, _ in self._tasks:
if task.should_run():
tasks.append(task)
task.last = datetime.now()
self._executor.map(self._run_command, tasks)
with self._sem:
deadline = self._tasks[0][0].when.total_seconds()
# Wait again until either the closest deadline is due
# or when a new task was added (this is to make sure
# that the scheduler stays responsive to schedule()
# calls made by other threads during its runtime)
self._evt.wait(timeout=deadline)
def stop(self):
with self._sem:
self._stop = True
class AsyncTaskScheduler(TaskScheduler):
"""
An asynchronous task scheduler. Not thread
safe.
"""
def __init__(self, concurrency_limit: int | None = None):
super().__init__(concurrency_limit)
if self.concurrency_limit is None:
self.concurrency_limit = 20
self._sem = structio.Lock()
# Task concurrency limiter
self._cap = structio.Semaphore(self.concurrency_limit)
self._evt = structio.Event()
self._tie = 0
self._tasks: list[tuple[ScheduledTask, int]] = []
async def _run_command(self, task: ScheduledTask):
print(f"[{datetime.now().strftime('%d/%m/%Y %H:%M:%S %p')}] Running command {task.command!r}")
# TODO: await structio.parallel.run(...)
# Allow more commands to be executed by releasing
# a slot
await self._cap.release()
def schedule(self, command: str, when: timedelta) -> ScheduledTask:
task = ScheduledTask(command, when)
self._tasks.append((task, self._tie))
self._tie += 1
self._tasks.sort()
# Wakeup other task blocked in start()
self._evt.set()
# Events are cheap, so they don't have a
# clear method. We just create a new event
self._evt = structio.Event()
return task
def unschedule(self, task: ScheduledTask):
for t in self._tasks:
if t[0] is task:
self._tasks.remove(t)
self._tasks.sort()
break
def unschedule_all(self, task: ScheduledTask) -> int:
result = 0
for t in self._tasks:
if t[0] == task:
self._tasks.remove(t)
result += 1
if result > 0:
self._tasks.sort()
return result
async def start(self):
if self.running:
raise RuntimeError("scheduler is already running")
self._stop = False
async with structio.create_pool() as pool:
while True:
if self._stop:
self._running = False
break
if not self._tasks:
await self._evt.wait()
tasks: list[ScheduledTask] = []
for task, _ in self._tasks:
if task.should_run():
tasks.append(task)
task.last = datetime.now()
for task in tasks:
await self._cap.acquire()
pool.spawn(self._run_command, task)
deadline = self._tasks[0][0].when.total_seconds()
# Wait again until either the closest deadline is due
# or when a new task was added (this is to make sure
# that the scheduler stays responsive to schedule()
# calls made by other tasks during its runtime)
with structio.skip_after(deadline):
await self._evt.wait()
def stop(self):
self._stop = True
def main_threaded():
print("Using synchronous scheduler")
scheduler: TaskScheduler = SyncTaskScheduler()
scheduler.schedule("test", parse_delay("1s"))
scheduler.schedule("test 2", parse_delay("2s"))
scheduler_thread = Thread(target=scheduler.start)
scheduler_thread.start()
time.sleep(5)
print("Adding 3rd task")
scheduler.schedule("test 3", parse_delay("3s"))
time.sleep(10)
print("Removing third task")
scheduler.unschedule_all(ScheduledTask("test 3", parse_delay("3s")))
time.sleep(5)
print("Stopping")
scheduler.stop()
scheduler_thread.join()
print("Stopped")
async def main_async():
print("Using asynchronous scheduler")
scheduler: TaskScheduler = AsyncTaskScheduler()
scheduler.schedule("test", parse_delay("1s"))
scheduler.schedule("test 2", parse_delay("2s"))
async with structio.create_pool() as pool:
pool.spawn(scheduler.start)
await structio.sleep(5)
print("Adding 3rd task")
scheduler.schedule("test 3", parse_delay("3s"))
await structio.sleep(10)
print("Removing third task")
scheduler.unschedule_all(ScheduledTask("test 3", parse_delay("3s")))
await structio.sleep(5)
print("Stopping")
scheduler.stop()
print("Stopped")
structio.run(main_async)
main_threaded()

10
tests/httpcore_test.py Normal file
View File

@ -0,0 +1,10 @@
import httpcore
import structio
async def main():
pool = httpcore.AsyncConnectionPool()
print(await pool.request("GET", "http://example.com"))
structio.run(main)