Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
grader支持获取input
  • Loading branch information
Mr-Python-in-China committed Feb 17, 2025
commit a1313749aca091048d21ae3b8dd16fa5b9d7ac7d
86 changes: 44 additions & 42 deletions cyaron/compare.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import absolute_import, print_function

import multiprocessing
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from io import open
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, cast

from cyaron.consts import *
from cyaron.graders import CYaRonGraders, GraderType
from cyaron.graders import CYaRonGraders, GraderType3
from cyaron.utils import *

from . import log
Expand All @@ -27,14 +26,16 @@ def __str__(self):
return "In program: '{}'. {}".format(self.name, self.mismatch)


PrgoramType = Optional[Union[str, Tuple[str, ...], List[str]]]
PrgoramType = Union[str, Tuple[str, ...], List[str]]


class Compare:

@staticmethod
def __compare_two(name, content, std, grader):
result, info = CYaRonGraders.invoke(grader, content, std)
def __compare_two(name: PrgoramType, content: str, std: str,
input_content: str, grader: Union[str, GraderType3]):
result, info = CYaRonGraders.invoke(grader, content, std,
input_content)
status = "Correct" if result else "!!!INCORRECT!!!"
info = info if info is not None else ""
log.debug("{}: {} {}".format(name, status, info))
Expand Down Expand Up @@ -77,7 +78,7 @@ def output(cls, *files, **kwargs):
("stop_on_incorrect", None),
),
)
std = kwargs["std"]
std: IO = kwargs["std"]
grader = kwargs["grader"]
max_workers = kwargs["max_workers"]
job_pool = kwargs["job_pool"]
Expand All @@ -101,13 +102,18 @@ def get_std():
return cls.__process_output_file(std)[1]

if job_pool is not None:
std = job_pool.submit(get_std).result()
std_answer = job_pool.submit(get_std).result()
else:
std = get_std()
std_answer = get_std()

with open(std.input_filename, "r", newline="\n",
encoding="utf-8") as input_file:
input_text = input_file.read()

def do(file):
(file_name, content) = cls.__process_output_file(file)
cls.__compare_two(file_name, content, std, grader)
cls.__compare_two(file_name, content, std_answer, input_text,
grader)

if job_pool is not None:
job_pool.map(do, files)
Expand All @@ -121,8 +127,8 @@ def program(cls,
std: Optional[Union[str, IO]] = None,
std_program: Optional[Union[str, Tuple[str, ...],
List[str]]] = None,
grader: Union[str, GraderType] = DEFAULT_GRADER,
max_workers: int = -1,
grader: Union[str, GraderType3] = DEFAULT_GRADER,
max_workers: Optional[int] = -1,
job_pool: Optional[ThreadPoolExecutor] = None,
stop_on_incorrect=None):
"""
Expand Down Expand Up @@ -182,7 +188,7 @@ def get_std_from_std_program():
elif std is not None:

def get_std_from_std_file():
return cls.__process_output_file(std)[1]
return cls.__process_output_file(cast(Union[str, IO], std))[1]

if job_pool is not None:
std = job_pool.submit(get_std_from_std_file).result()
Expand All @@ -197,33 +203,29 @@ def get_std_from_std_file():
"r",
newline="\n",
encoding="utf-8") as input_file:

def do(program_name):
timeout = None
if (list_like(program_name) and len(program_name) == 2
and int_like(program_name[-1])):
program_name, timeout = program_name
if timeout is None:
content = subprocess.check_output(
program_name,
shell=(not list_like(program_name)),
stdin=input_file,
universal_newlines=True,
encoding="utf-8",
)
else:
content = subprocess.check_output(
program_name,
shell=(not list_like(program_name)),
stdin=input_file,
universal_newlines=True,
timeout=timeout,
encoding="utf-8",
)
cls.__compare_two(program_name, content, std, grader)

if job_pool is not None:
job_pool.map(do, programs)
input_text = input_file.read()

def do(program_name: Union[PrgoramType, Tuple[PrgoramType, float]]):
timeout = None
if isinstance(program_name, tuple) and len(program_name) == 2 and (
isinstance(program_name[1], float)
or isinstance(program_name[1], int)):
program_name, timeout = cast(Tuple[PrgoramType, float],
program_name)
else:
for program in programs:
do(program)
program_name = cast(PrgoramType, program_name)
content = subprocess.check_output(
list(program_name)
if isinstance(program_name, tuple) else program_name,
shell=(not list_like(program_name)),
input=input_text,
universal_newlines=True,
encoding="utf-8",
timeout=timeout)
cls.__compare_two(program_name, content, std, input_text, grader)

if job_pool is not None:
job_pool.map(do, programs)
else:
for program in programs:
do(program)
2 changes: 1 addition & 1 deletion cyaron/graders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .graderregistry import CYaRonGraders, GraderType
from .graderregistry import CYaRonGraders, GraderType2, GraderType3

from .fulltext import fulltext
from .noipstyle import noipstyle
42 changes: 32 additions & 10 deletions cyaron/graders/graderregistry.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,51 @@
from typing import Callable, Tuple, Dict, Union
from typing import Callable, Tuple, Dict, Union, Any

__all__ = ['CYaRonGraders', 'GraderType']
__all__ = ['CYaRonGraders', 'GraderType2', 'GraderType3']

GraderType = Callable[[str, str], Tuple[bool, Union[str, None]]]
GraderType2 = Callable[[str, str], Tuple[bool, Any]]
GraderType3 = Callable[[str, str, str], Tuple[bool, Any]]


class GraderRegistry:
"""A registry for grader functions."""
_registry: Dict[str, GraderType] = {}
_registry: Dict[str, GraderType3] = {}

def grader2(self, name: str):
"""
This decorator registers a grader function under a specific name in the registry.

The function being decorated should accept exactly two parameters (excluding
the content input).
"""

def wrapper(func: GraderType2):
self._registry[name] = lambda content, std, _: func(content, std)
return func

return wrapper

grader = grader2

def grader(self, name: str):
"""A decorator to register a grader function."""
def grader3(self, name: str):
"""
This decorator registers a grader function under a specific name in the registry.

The function being decorated should accept exactly three parameters.
"""

def wrapper(func: GraderType):
def wrapper(func: GraderType3):
self._registry[name] = func
return func

return wrapper

def invoke(self, grader: Union[str, GraderType], content: str, std: str):
def invoke(self, grader: Union[str, GraderType3], content: str, std: str,
input_content: str):
"""Invoke a grader function by name or function object."""
if isinstance(grader, str):
return self._registry[grader](content, std)
return self._registry[grader](content, std, input_content)
else:
return grader(content, std)
return grader(content, std, input_content)

def check(self, name):
"""Check if a grader is registered."""
Expand Down
56 changes: 40 additions & 16 deletions cyaron/tests/compare_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ def test_file_input_success(self):
grader="NOIPStyle")

def test_file_input_fail(self):
with open("correct.py", "w") as f:
f.write("print(input())")
with open("std.py", "w") as f:
with open("incorrect.py", "w") as f:
f.write("print(input()+'154')")
with open("std.py", "w") as f:
f.write("print(input())")
io = IO()
io.input_writeln("233")
try:
with captured_output():
Compare.program((sys.executable, "correct.py"),
Compare.program((sys.executable, "incorrect.py"),
std_program=(sys.executable, "std.py"),
input=io,
grader="NOIPStyle")
Expand Down Expand Up @@ -178,10 +178,11 @@ def test_timeout(self):
else:
self.assertTrue(False)

def test_custom_grader_by_name(self):
def test_custom_grader2_by_name(self):
self.assertEqual(CYaRonGraders.grader, CYaRonGraders.grader2)

@CYaRonGraders.grader("CustomTestGrader")
def custom_test_grader(content: str, std: str):
@CYaRonGraders.grader("CustomTestGrader2")
def custom_test_grader2(content: str, std: str):
if content == '1\n' and std == '2\n':
return True, None
return False, "CustomTestGrader failed"
Expand All @@ -192,13 +193,38 @@ def custom_test_grader(content: str, std: str):
Compare.program("echo 1",
std=io,
input=IO(),
grader="CustomTestGrader")
grader="CustomTestGrader2")

try:
Compare.program("echo 2",
std=io,
input=IO(),
grader="CustomTestGrader")
grader="CustomTestGrader2")
except CompareMismatch as e:
self.assertEqual(e.name, 'echo 2')
self.assertEqual(e.mismatch, "CustomTestGrader failed")
else:
self.fail("Should raise CompareMismatch")

def test_custom_grader3_by_name(self):

@CYaRonGraders.grader3("CustomTestGrader3")
def custom_test_grader3(content: str, std: str, input_content: str):
if input_content == '0\n' and content == '1\n' and std == '2\n':
return True, None
return False, "CustomTestGrader failed"

io = IO()
io.input_writeln("0")
io.output_writeln("2")

Compare.program("echo 1", std=io, input=io, grader="CustomTestGrader3")

try:
Compare.program("echo 2",
std=io,
input=io,
grader='CustomTestGrader3')
except CompareMismatch as e:
self.assertEqual(e.name, 'echo 2')
self.assertEqual(e.mismatch, "CustomTestGrader failed")
Expand All @@ -207,23 +233,21 @@ def custom_test_grader(content: str, std: str):

def test_custom_grader_by_function(self):

def custom_test_grader(content: str, std: str):
if content == '1\n' and std == '2\n':
def custom_test_grader(content: str, std: str, input_content: str):
if input_content == '0\n' and content == '1\n' and std == '2\n':
return True, None
return False, "CustomTestGrader failed"

io = IO()
io.input_writeln("0")
io.output_writeln("2")

Compare.program("echo 1",
std=io,
input=IO(),
grader=custom_test_grader)
Compare.program("echo 1", std=io, input=io, grader=custom_test_grader)

try:
Compare.program("echo 2",
std=io,
input=IO(),
input=io,
grader=custom_test_grader)
except CompareMismatch as e:
self.assertEqual(e.name, 'echo 2')
Expand Down
Loading