From bf13ca03a5bb30e2e687c0efc9e93f8e8c73ce5e Mon Sep 17 00:00:00 2001 From: sobolevn Date: Sun, 8 Oct 2023 23:21:48 +0300 Subject: [PATCH 1/4] gh-110525: Add CAPI tests for `set` and `frozenset` objects --- Lib/test/test_capi/test_set.py | 175 +++++++++++++++++++++++++++++++++ Modules/Setup.stdlib.in | 2 +- Modules/_testcapi/parts.h | 1 + Modules/_testcapi/set.c | 156 +++++++++++++++++++++++++++++ Modules/_testcapimodule.c | 3 + 5 files changed, 336 insertions(+), 1 deletion(-) create mode 100644 Lib/test/test_capi/test_set.py create mode 100644 Modules/_testcapi/set.c diff --git a/Lib/test/test_capi/test_set.py b/Lib/test/test_capi/test_set.py new file mode 100644 index 00000000000000..04a975358240dd --- /dev/null +++ b/Lib/test/test_capi/test_set.py @@ -0,0 +1,175 @@ +import unittest + +from test.support import import_helper + +# Skip this test if the _testcapi module isn't available. +_testcapi = import_helper.import_module('_testcapi') + +class set_child(set): + pass + +class frozenset_child(frozenset): + pass + + +class TestSetCAPI(unittest.TestCase): + def assertImmutable(self, action, *args): + self.assertRaises(SystemError, action, frozenset(), *args) + self.assertRaises(SystemError, action, frozenset({1}), *args) + self.assertRaises(SystemError, action, frozenset_child(), *args) + self.assertRaises(SystemError, action, frozenset_child({1}), *args) + + def test_set_check(self): + check = _testcapi.set_check + self.assertTrue(check(set())) + self.assertTrue(check({1, 2})) + self.assertFalse(check(frozenset())) + self.assertTrue(check(set_child())) + self.assertFalse(check(frozenset_child())) + self.assertFalse(check(object())) + + def test_set_check_exact(self): + check = _testcapi.set_checkexact + self.assertTrue(check(set())) + self.assertTrue(check({1, 2})) + self.assertFalse(check(frozenset())) + self.assertFalse(check(set_child())) + self.assertFalse(check(frozenset_child())) + self.assertFalse(check(object())) + + def test_frozenset_check(self): + check = _testcapi.frozenset_check + self.assertFalse(check(set())) + self.assertTrue(check(frozenset())) + self.assertTrue(check(frozenset({1, 2}))) + self.assertFalse(check(set_child())) + self.assertTrue(check(frozenset_child())) + self.assertFalse(check(object())) + + def test_frozenset_check_exact(self): + check = _testcapi.frozenset_checkexact + self.assertFalse(check(set())) + self.assertTrue(check(frozenset())) + self.assertTrue(check(frozenset({1, 2}))) + self.assertFalse(check(set_child())) + self.assertFalse(check(frozenset_child())) + self.assertFalse(check(object())) + + def test_anyset_check(self): + check = _testcapi.anyset_check + self.assertTrue(check(set())) + self.assertTrue(check({1, 2})) + self.assertTrue(check(frozenset())) + self.assertTrue(check(frozenset({1, 2}))) + self.assertTrue(check(set_child())) + self.assertTrue(check(frozenset_child())) + self.assertFalse(check(object())) + + def test_anyset_check_exact(self): + check = _testcapi.anyset_checkexact + self.assertTrue(check(set())) + self.assertTrue(check({1, 2})) + self.assertTrue(check(frozenset())) + self.assertTrue(check(frozenset({1, 2}))) + self.assertFalse(check(set_child())) + self.assertFalse(check(frozenset_child())) + self.assertFalse(check(object())) + + def test_set_new(self): + new = _testcapi.set_new + self.assertEqual(new().__class__, set) + self.assertEqual(new(), set()) + self.assertEqual(new((1, 1, 2)), {1, 2}) + with self.assertRaisesRegex(TypeError, 'object is not iterable'): + new(object()) + with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"): + new((1, {})) + + def test_frozenset_new(self): + new = _testcapi.frozenset_new + self.assertEqual(new().__class__, frozenset) + self.assertEqual(new(), frozenset()) + self.assertEqual(new((1, 1, 2)), frozenset({1, 2})) + with self.assertRaisesRegex(TypeError, 'object is not iterable'): + new(object()) + with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"): + new((1, {})) + + def test_set_size(self): + l = _testcapi.set_size + self.assertEqual(l(set()), 0) + self.assertEqual(l(frozenset()), 0) + self.assertEqual(l({1, 1, 2}), 2) + self.assertEqual(l(frozenset({1, 1, 2})), 2) + self.assertEqual(l(set_child((1, 2, 3))), 3) + self.assertEqual(l(frozenset_child((1, 2, 3))), 3) + with self.assertRaises(SystemError): + l([]) + + def test_set_get_size(self): + l = _testcapi.set_get_size + self.assertEqual(l(set()), 0) + self.assertEqual(l(frozenset()), 0) + self.assertEqual(l({1, 1, 2}), 2) + self.assertEqual(l(frozenset({1, 1, 2})), 2) + self.assertEqual(l(set_child((1, 2, 3))), 3) + self.assertEqual(l(frozenset_child((1, 2, 3))), 3) + # CRASHES: l([]) + + def test_set_contains(self): + c = _testcapi.set_contains + for cls in (set, frozenset, set_child, frozenset_child): + with self.subTest(cls=cls): + instance = cls((1, 2)) + self.assertTrue(c(instance, 1)) + self.assertFalse(c(instance, 'missing')) + + def test_add(self): + add = _testcapi.set_add + for cls in (set, set_child): + with self.subTest(cls=cls): + instance = cls((1, 2)) + self.assertEqual(add(instance, 1), 0) + self.assertEqual(instance, {1, 2}) + self.assertEqual(add(instance, 3), 0) + self.assertEqual(instance, {1, 2, 3}) + self.assertImmutable(add, 1) + + def test_discard(self): + discard = _testcapi.set_discard + for cls in (set, set_child): + with self.subTest(cls=cls): + instance = cls((1, 2)) + self.assertEqual(discard(instance, 3), 0) + self.assertEqual(instance, {1, 2}) + self.assertEqual(discard(instance, 1), 1) + self.assertEqual(instance, {2}) + self.assertEqual(discard(instance, 2), 1) + self.assertEqual(instance, set()) + # Discarding from empty set works + self.assertEqual(discard(instance, 2), 0) + self.assertEqual(instance, set()) + self.assertImmutable(discard, 1) + + def test_pop(self): + pop = _testcapi.set_pop + orig = (1, 2) + for cls in (set, set_child): + with self.subTest(cls=cls): + instance = cls(orig) + self.assertIn(pop(instance), orig) + self.assertEqual(len(instance), 1) + self.assertIn(pop(instance), orig) + self.assertEqual(len(instance), 0) + with self.assertRaises(KeyError): + pop(instance) + self.assertImmutable(pop) + + def test_clear(self): + clear = _testcapi.set_clear + for cls in (set, set_child): + with self.subTest(cls=cls): + instance = cls((1, 2)) + self.assertEqual(clear(instance), 0) + self.assertEqual(instance, set()) + self.assertImmutable(clear) diff --git a/Modules/Setup.stdlib.in b/Modules/Setup.stdlib.in index 7b3216a50bb284..8428142a852529 100644 --- a/Modules/Setup.stdlib.in +++ b/Modules/Setup.stdlib.in @@ -159,7 +159,7 @@ @MODULE__XXTESTFUZZ_TRUE@_xxtestfuzz _xxtestfuzz/_xxtestfuzz.c _xxtestfuzz/fuzzer.c @MODULE__TESTBUFFER_TRUE@_testbuffer _testbuffer.c @MODULE__TESTINTERNALCAPI_TRUE@_testinternalcapi _testinternalcapi.c _testinternalcapi/test_lock.c _testinternalcapi/pytime.c -@MODULE__TESTCAPI_TRUE@_testcapi _testcapimodule.c _testcapi/vectorcall.c _testcapi/vectorcall_limited.c _testcapi/heaptype.c _testcapi/abstract.c _testcapi/unicode.c _testcapi/dict.c _testcapi/getargs.c _testcapi/datetime.c _testcapi/docstring.c _testcapi/mem.c _testcapi/watchers.c _testcapi/long.c _testcapi/float.c _testcapi/structmember.c _testcapi/exceptions.c _testcapi/code.c _testcapi/buffer.c _testcapi/pyatomic.c _testcapi/pyos.c _testcapi/immortal.c _testcapi/heaptype_relative.c _testcapi/gc.c +@MODULE__TESTCAPI_TRUE@_testcapi _testcapimodule.c _testcapi/vectorcall.c _testcapi/vectorcall_limited.c _testcapi/heaptype.c _testcapi/abstract.c _testcapi/unicode.c _testcapi/dict.c _testcapi/set.c _testcapi/getargs.c _testcapi/datetime.c _testcapi/docstring.c _testcapi/mem.c _testcapi/watchers.c _testcapi/long.c _testcapi/float.c _testcapi/structmember.c _testcapi/exceptions.c _testcapi/code.c _testcapi/buffer.c _testcapi/pyatomic.c _testcapi/pyos.c _testcapi/immortal.c _testcapi/heaptype_relative.c _testcapi/gc.c @MODULE__TESTCLINIC_TRUE@_testclinic _testclinic.c @MODULE__TESTCLINIC_LIMITED_TRUE@_testclinic_limited _testclinic_limited.c diff --git a/Modules/_testcapi/parts.h b/Modules/_testcapi/parts.h index 24abe54814e611..acdba86504f58e 100644 --- a/Modules/_testcapi/parts.h +++ b/Modules/_testcapi/parts.h @@ -34,6 +34,7 @@ int _PyTestCapi_Init_Watchers(PyObject *module); int _PyTestCapi_Init_Long(PyObject *module); int _PyTestCapi_Init_Float(PyObject *module); int _PyTestCapi_Init_Dict(PyObject *module); +int _PyTestCapi_Init_Set(PyObject *module); int _PyTestCapi_Init_Structmember(PyObject *module); int _PyTestCapi_Init_Exceptions(PyObject *module); int _PyTestCapi_Init_Code(PyObject *module); diff --git a/Modules/_testcapi/set.c b/Modules/_testcapi/set.c new file mode 100644 index 00000000000000..d9a7a2ba87b336 --- /dev/null +++ b/Modules/_testcapi/set.c @@ -0,0 +1,156 @@ +#include // ptrdiff_t + +#include "parts.h" +#include "util.h" + +static PyObject * +set_check(PyObject *self, PyObject *obj) +{ + RETURN_INT(PySet_Check(obj)); +} + +static PyObject * +set_checkexact(PyObject *self, PyObject *obj) +{ + RETURN_INT(PySet_CheckExact(obj)); +} + +static PyObject * +frozenset_check(PyObject *self, PyObject *obj) +{ + RETURN_INT(PyFrozenSet_Check(obj)); +} + +static PyObject * +frozenset_checkexact(PyObject *self, PyObject *obj) +{ + RETURN_INT(PyFrozenSet_CheckExact(obj)); +} + +static PyObject * +anyset_check(PyObject *self, PyObject *obj) +{ + RETURN_INT(PyAnySet_Check(obj)); +} + +static PyObject * +anyset_checkexact(PyObject *self, PyObject *obj) +{ + RETURN_INT(PyAnySet_CheckExact(obj)); +} + +static PyObject * +set_new(PyObject *self, PyObject *args) +{ + PyObject *iterable = NULL; + if (!PyArg_ParseTuple(args, "|O", &iterable)) { + return NULL; + } + return PySet_New(iterable); +} + +static PyObject * +frozenset_new(PyObject *self, PyObject *args) +{ + PyObject *iterable = NULL; + if (!PyArg_ParseTuple(args, "|O", &iterable)) { + return NULL; + } + return PyFrozenSet_New(iterable); +} + +static PyObject * +set_size(PyObject *self, PyObject *obj) +{ + NULLABLE(obj); + RETURN_SIZE(PySet_Size(obj)); +} + +static PyObject * +set_get_size(PyObject *self, PyObject *obj) +{ + NULLABLE(obj); + RETURN_SIZE(PySet_GET_SIZE(obj)); +} + +static PyObject * +set_contains(PyObject *self, PyObject *args) +{ + PyObject *obj, *item; + if (!PyArg_ParseTuple(args, "OO", &obj, &item)) { + return NULL; + } + NULLABLE(obj); + NULLABLE(item); + RETURN_INT(PySet_Contains(obj, item)); +} + +static PyObject * +set_add(PyObject *self, PyObject *args) +{ + PyObject *obj, *item; + if (!PyArg_ParseTuple(args, "OO", &obj, &item)) { + return NULL; + } + NULLABLE(obj); + NULLABLE(item); + RETURN_INT(PySet_Add(obj, item)); +} + +static PyObject * +set_discard(PyObject *self, PyObject *args) +{ + PyObject *obj, *item; + if (!PyArg_ParseTuple(args, "OO", &obj, &item)) { + return NULL; + } + NULLABLE(obj); + NULLABLE(item); + RETURN_INT(PySet_Discard(obj, item)); +} + +static PyObject * +set_pop(PyObject *self, PyObject *obj) +{ + NULLABLE(obj); + return PySet_Pop(obj); +} + +static PyObject * +set_clear(PyObject *self, PyObject *obj) +{ + NULLABLE(obj); + RETURN_INT(PySet_Clear(obj)); +} + +static PyMethodDef test_methods[] = { + {"set_check", set_check, METH_O}, + {"set_checkexact", set_checkexact, METH_O}, + {"frozenset_check", frozenset_check, METH_O}, + {"frozenset_checkexact", frozenset_checkexact, METH_O}, + {"anyset_check", anyset_check, METH_O}, + {"anyset_checkexact", anyset_checkexact, METH_O}, + + {"set_new", set_new, METH_VARARGS}, + {"frozenset_new", frozenset_new, METH_VARARGS}, + + {"set_size", set_size, METH_O}, + {"set_get_size", set_get_size, METH_O}, + {"set_contains", set_contains, METH_VARARGS}, + {"set_add", set_add, METH_VARARGS}, + {"set_discard", set_discard, METH_VARARGS}, + {"set_pop", set_pop, METH_O}, + {"set_clear", set_clear, METH_O}, + + {NULL}, +}; + +int +_PyTestCapi_Init_Set(PyObject *m) +{ + if (PyModule_AddFunctions(m, test_methods) < 0) { + return -1; + } + + return 0; +} diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c index a46d986c18ecd4..ce3d0b1b1b005c 100644 --- a/Modules/_testcapimodule.c +++ b/Modules/_testcapimodule.c @@ -3981,6 +3981,9 @@ PyInit__testcapi(void) if (_PyTestCapi_Init_Dict(m) < 0) { return NULL; } + if (_PyTestCapi_Init_Set(m) < 0) { + return NULL; + } if (_PyTestCapi_Init_Structmember(m) < 0) { return NULL; } From 314a7b3a93dab007384dc4e187d3bdf372942cdb Mon Sep 17 00:00:00 2001 From: sobolevn Date: Sun, 8 Oct 2023 23:32:58 +0300 Subject: [PATCH 2/4] Fix Win builds --- PCbuild/_testcapi.vcxproj | 1 + 1 file changed, 1 insertion(+) diff --git a/PCbuild/_testcapi.vcxproj b/PCbuild/_testcapi.vcxproj index 0a02929db438b8..0f33c5a76ade9d 100644 --- a/PCbuild/_testcapi.vcxproj +++ b/PCbuild/_testcapi.vcxproj @@ -102,6 +102,7 @@ + From 9feda284557142ab60d2564b5de0cba042a57025 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Mon, 9 Oct 2023 09:51:18 +0300 Subject: [PATCH 3/4] Address review --- Lib/test/test_capi/test_set.py | 144 +++++++++++++++++++++------------ Modules/_testcapi/set.c | 6 ++ 2 files changed, 97 insertions(+), 53 deletions(-) diff --git a/Lib/test/test_capi/test_set.py b/Lib/test/test_capi/test_set.py index 04a975358240dd..dbf86eddcb98b5 100644 --- a/Lib/test/test_capi/test_set.py +++ b/Lib/test/test_capi/test_set.py @@ -5,10 +5,10 @@ # Skip this test if the _testcapi module isn't available. _testcapi = import_helper.import_module('_testcapi') -class set_child(set): +class set_subclass(set): pass -class frozenset_child(frozenset): +class frozenset_subclass(frozenset): pass @@ -16,44 +16,48 @@ class TestSetCAPI(unittest.TestCase): def assertImmutable(self, action, *args): self.assertRaises(SystemError, action, frozenset(), *args) self.assertRaises(SystemError, action, frozenset({1}), *args) - self.assertRaises(SystemError, action, frozenset_child(), *args) - self.assertRaises(SystemError, action, frozenset_child({1}), *args) + self.assertRaises(SystemError, action, frozenset_subclass(), *args) + self.assertRaises(SystemError, action, frozenset_subclass({1}), *args) def test_set_check(self): check = _testcapi.set_check self.assertTrue(check(set())) self.assertTrue(check({1, 2})) self.assertFalse(check(frozenset())) - self.assertTrue(check(set_child())) - self.assertFalse(check(frozenset_child())) + self.assertTrue(check(set_subclass())) + self.assertFalse(check(frozenset_subclass())) self.assertFalse(check(object())) + # CRASHES: check(NULL) def test_set_check_exact(self): check = _testcapi.set_checkexact self.assertTrue(check(set())) self.assertTrue(check({1, 2})) self.assertFalse(check(frozenset())) - self.assertFalse(check(set_child())) - self.assertFalse(check(frozenset_child())) + self.assertFalse(check(set_subclass())) + self.assertFalse(check(frozenset_subclass())) self.assertFalse(check(object())) + # CRASHES: check(NULL) def test_frozenset_check(self): check = _testcapi.frozenset_check self.assertFalse(check(set())) self.assertTrue(check(frozenset())) self.assertTrue(check(frozenset({1, 2}))) - self.assertFalse(check(set_child())) - self.assertTrue(check(frozenset_child())) + self.assertFalse(check(set_subclass())) + self.assertTrue(check(frozenset_subclass())) self.assertFalse(check(object())) + # CRASHES: check(NULL) def test_frozenset_check_exact(self): check = _testcapi.frozenset_checkexact self.assertFalse(check(set())) self.assertTrue(check(frozenset())) self.assertTrue(check(frozenset({1, 2}))) - self.assertFalse(check(set_child())) - self.assertFalse(check(frozenset_child())) + self.assertFalse(check(set_subclass())) + self.assertFalse(check(frozenset_subclass())) self.assertFalse(check(object())) + # CRASHES: check(NULL) def test_anyset_check(self): check = _testcapi.anyset_check @@ -61,9 +65,10 @@ def test_anyset_check(self): self.assertTrue(check({1, 2})) self.assertTrue(check(frozenset())) self.assertTrue(check(frozenset({1, 2}))) - self.assertTrue(check(set_child())) - self.assertTrue(check(frozenset_child())) + self.assertTrue(check(set_subclass())) + self.assertTrue(check(frozenset_subclass())) self.assertFalse(check(object())) + # CRASHES: check(NULL) def test_anyset_check_exact(self): check = _testcapi.anyset_checkexact @@ -71,73 +76,92 @@ def test_anyset_check_exact(self): self.assertTrue(check({1, 2})) self.assertTrue(check(frozenset())) self.assertTrue(check(frozenset({1, 2}))) - self.assertFalse(check(set_child())) - self.assertFalse(check(frozenset_child())) + self.assertFalse(check(set_subclass())) + self.assertFalse(check(frozenset_subclass())) self.assertFalse(check(object())) + # CRASHES: check(NULL) def test_set_new(self): - new = _testcapi.set_new - self.assertEqual(new().__class__, set) - self.assertEqual(new(), set()) - self.assertEqual(new((1, 1, 2)), {1, 2}) + set_new = _testcapi.set_new + self.assertEqual(set_new().__class__, set) + self.assertEqual(set_new(), set()) + self.assertEqual(set_new((1, 1, 2)), {1, 2}) with self.assertRaisesRegex(TypeError, 'object is not iterable'): - new(object()) + set_new(object()) + with self.assertRaisesRegex(TypeError, 'object is not iterable'): + set_new(None) with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"): - new((1, {})) + set_new((1, {})) def test_frozenset_new(self): - new = _testcapi.frozenset_new - self.assertEqual(new().__class__, frozenset) - self.assertEqual(new(), frozenset()) - self.assertEqual(new((1, 1, 2)), frozenset({1, 2})) + frozenset_new = _testcapi.frozenset_new + self.assertEqual(frozenset_new().__class__, frozenset) + self.assertEqual(frozenset_new(), frozenset()) + self.assertEqual(frozenset_new((1, 1, 2)), frozenset({1, 2})) + with self.assertRaisesRegex(TypeError, 'object is not iterable'): + frozenset_new(object()) with self.assertRaisesRegex(TypeError, 'object is not iterable'): - new(object()) + frozenset_new(None) with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"): - new((1, {})) + frozenset_new((1, {})) def test_set_size(self): - l = _testcapi.set_size - self.assertEqual(l(set()), 0) - self.assertEqual(l(frozenset()), 0) - self.assertEqual(l({1, 1, 2}), 2) - self.assertEqual(l(frozenset({1, 1, 2})), 2) - self.assertEqual(l(set_child((1, 2, 3))), 3) - self.assertEqual(l(frozenset_child((1, 2, 3))), 3) + get_size = _testcapi.set_size + self.assertEqual(get_size(set()), 0) + self.assertEqual(get_size(frozenset()), 0) + self.assertEqual(get_size({1, 1, 2}), 2) + self.assertEqual(get_size(frozenset({1, 1, 2})), 2) + self.assertEqual(get_size(set_subclass((1, 2, 3))), 3) + self.assertEqual(get_size(frozenset_subclass((1, 2, 3))), 3) with self.assertRaises(SystemError): - l([]) + get_size([]) + # CRASHES: get_size(NULL) def test_set_get_size(self): - l = _testcapi.set_get_size - self.assertEqual(l(set()), 0) - self.assertEqual(l(frozenset()), 0) - self.assertEqual(l({1, 1, 2}), 2) - self.assertEqual(l(frozenset({1, 1, 2})), 2) - self.assertEqual(l(set_child((1, 2, 3))), 3) - self.assertEqual(l(frozenset_child((1, 2, 3))), 3) - # CRASHES: l([]) + get_size = _testcapi.set_get_size + self.assertEqual(get_size(set()), 0) + self.assertEqual(get_size(frozenset()), 0) + self.assertEqual(get_size({1, 1, 2}), 2) + self.assertEqual(get_size(frozenset({1, 1, 2})), 2) + self.assertEqual(get_size(set_subclass((1, 2, 3))), 3) + self.assertEqual(get_size(frozenset_subclass((1, 2, 3))), 3) + # CRASHES: get_size(NULL) + # CRASHES: get_size(object()) def test_set_contains(self): - c = _testcapi.set_contains - for cls in (set, frozenset, set_child, frozenset_child): + contains = _testcapi.set_contains + for cls in (set, frozenset, set_subclass, frozenset_subclass): with self.subTest(cls=cls): instance = cls((1, 2)) - self.assertTrue(c(instance, 1)) - self.assertFalse(c(instance, 'missing')) + self.assertTrue(contains(instance, 1)) + self.assertFalse(contains(instance, 'missing')) + with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"): + contains(instance, []) + # CRASHES: contains(instance, NULL) + # CRASHES: contains(NULL, object()) + # CRASHES: contains(NULL, NULL) def test_add(self): add = _testcapi.set_add - for cls in (set, set_child): + for cls in (set, set_subclass): with self.subTest(cls=cls): instance = cls((1, 2)) self.assertEqual(add(instance, 1), 0) self.assertEqual(instance, {1, 2}) self.assertEqual(add(instance, 3), 0) self.assertEqual(instance, {1, 2, 3}) + with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"): + add(instance, []) + # CRASHES: add(NULL, object()) + # CRASHES: add(instance, NULL) + # CRASHES: add(NULL, NULL) + with self.assertRaises(SystemError): + add(object(), 1) self.assertImmutable(add, 1) def test_discard(self): discard = _testcapi.set_discard - for cls in (set, set_child): + for cls in (set, set_subclass): with self.subTest(cls=cls): instance = cls((1, 2)) self.assertEqual(discard(instance, 3), 0) @@ -146,15 +170,21 @@ def test_discard(self): self.assertEqual(instance, {2}) self.assertEqual(discard(instance, 2), 1) self.assertEqual(instance, set()) - # Discarding from empty set works self.assertEqual(discard(instance, 2), 0) self.assertEqual(instance, set()) + with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"): + discard(instance, []) + # CRASHES: discard(NULL, object()) + # CRASHES: discard(instance, NULL) + # CRASHES: discard(NULL, NULL) + with self.assertRaises(SystemError): + discard(object(), 1) self.assertImmutable(discard, 1) def test_pop(self): pop = _testcapi.set_pop orig = (1, 2) - for cls in (set, set_child): + for cls in (set, set_subclass): with self.subTest(cls=cls): instance = cls(orig) self.assertIn(pop(instance), orig) @@ -163,13 +193,21 @@ def test_pop(self): self.assertEqual(len(instance), 0) with self.assertRaises(KeyError): pop(instance) + # CRASHES: pop(NULL) + with self.assertRaises(SystemError): + pop(object()) self.assertImmutable(pop) def test_clear(self): clear = _testcapi.set_clear - for cls in (set, set_child): + for cls in (set, set_subclass): with self.subTest(cls=cls): instance = cls((1, 2)) self.assertEqual(clear(instance), 0) self.assertEqual(instance, set()) + self.assertEqual(clear(instance), 0) + self.assertEqual(instance, set()) + # CRASHES: clear(NULL) + with self.assertRaises(SystemError): + clear(object()) self.assertImmutable(clear) diff --git a/Modules/_testcapi/set.c b/Modules/_testcapi/set.c index d9a7a2ba87b336..f68a1859698132 100644 --- a/Modules/_testcapi/set.c +++ b/Modules/_testcapi/set.c @@ -6,36 +6,42 @@ static PyObject * set_check(PyObject *self, PyObject *obj) { + NULLABLE(obj); RETURN_INT(PySet_Check(obj)); } static PyObject * set_checkexact(PyObject *self, PyObject *obj) { + NULLABLE(obj); RETURN_INT(PySet_CheckExact(obj)); } static PyObject * frozenset_check(PyObject *self, PyObject *obj) { + NULLABLE(obj); RETURN_INT(PyFrozenSet_Check(obj)); } static PyObject * frozenset_checkexact(PyObject *self, PyObject *obj) { + NULLABLE(obj); RETURN_INT(PyFrozenSet_CheckExact(obj)); } static PyObject * anyset_check(PyObject *self, PyObject *obj) { + NULLABLE(obj); RETURN_INT(PyAnySet_Check(obj)); } static PyObject * anyset_checkexact(PyObject *self, PyObject *obj) { + NULLABLE(obj); RETURN_INT(PyAnySet_CheckExact(obj)); } From 9b8f7a59a29d35c9547a019a5c37ee36542f8547 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Mon, 9 Oct 2023 11:23:44 +0300 Subject: [PATCH 4/4] Address review --- Lib/test/test_capi/test_set.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/Lib/test/test_capi/test_set.py b/Lib/test/test_capi/test_set.py index dbf86eddcb98b5..e9165e7e6806dd 100644 --- a/Lib/test/test_capi/test_set.py +++ b/Lib/test/test_capi/test_set.py @@ -86,10 +86,11 @@ def test_set_new(self): self.assertEqual(set_new().__class__, set) self.assertEqual(set_new(), set()) self.assertEqual(set_new((1, 1, 2)), {1, 2}) + self.assertEqual(set_new([1, 1, 2]), {1, 2}) with self.assertRaisesRegex(TypeError, 'object is not iterable'): set_new(object()) with self.assertRaisesRegex(TypeError, 'object is not iterable'): - set_new(None) + set_new(1) with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"): set_new((1, {})) @@ -98,10 +99,11 @@ def test_frozenset_new(self): self.assertEqual(frozenset_new().__class__, frozenset) self.assertEqual(frozenset_new(), frozenset()) self.assertEqual(frozenset_new((1, 1, 2)), frozenset({1, 2})) + self.assertEqual(frozenset_new([1, 1, 2]), frozenset({1, 2})) with self.assertRaisesRegex(TypeError, 'object is not iterable'): frozenset_new(object()) with self.assertRaisesRegex(TypeError, 'object is not iterable'): - frozenset_new(None) + frozenset_new(1) with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"): frozenset_new((1, {})) @@ -114,7 +116,7 @@ def test_set_size(self): self.assertEqual(get_size(set_subclass((1, 2, 3))), 3) self.assertEqual(get_size(frozenset_subclass((1, 2, 3))), 3) with self.assertRaises(SystemError): - get_size([]) + get_size(object()) # CRASHES: get_size(NULL) def test_set_get_size(self): @@ -137,9 +139,9 @@ def test_set_contains(self): self.assertFalse(contains(instance, 'missing')) with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"): contains(instance, []) - # CRASHES: contains(instance, NULL) - # CRASHES: contains(NULL, object()) - # CRASHES: contains(NULL, NULL) + # CRASHES: contains(instance, NULL) + # CRASHES: contains(NULL, object()) + # CRASHES: contains(NULL, NULL) def test_add(self): add = _testcapi.set_add @@ -152,12 +154,12 @@ def test_add(self): self.assertEqual(instance, {1, 2, 3}) with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"): add(instance, []) - # CRASHES: add(NULL, object()) - # CRASHES: add(instance, NULL) - # CRASHES: add(NULL, NULL) with self.assertRaises(SystemError): add(object(), 1) self.assertImmutable(add, 1) + # CRASHES: add(NULL, object()) + # CRASHES: add(instance, NULL) + # CRASHES: add(NULL, NULL) def test_discard(self): discard = _testcapi.set_discard @@ -174,12 +176,12 @@ def test_discard(self): self.assertEqual(instance, set()) with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"): discard(instance, []) - # CRASHES: discard(NULL, object()) - # CRASHES: discard(instance, NULL) - # CRASHES: discard(NULL, NULL) with self.assertRaises(SystemError): discard(object(), 1) self.assertImmutable(discard, 1) + # CRASHES: discard(NULL, object()) + # CRASHES: discard(instance, NULL) + # CRASHES: discard(NULL, NULL) def test_pop(self): pop = _testcapi.set_pop @@ -193,10 +195,10 @@ def test_pop(self): self.assertEqual(len(instance), 0) with self.assertRaises(KeyError): pop(instance) - # CRASHES: pop(NULL) with self.assertRaises(SystemError): pop(object()) self.assertImmutable(pop) + # CRASHES: pop(NULL) def test_clear(self): clear = _testcapi.set_clear @@ -207,7 +209,7 @@ def test_clear(self): self.assertEqual(instance, set()) self.assertEqual(clear(instance), 0) self.assertEqual(instance, set()) - # CRASHES: clear(NULL) with self.assertRaises(SystemError): clear(object()) self.assertImmutable(clear) + # CRASHES: clear(NULL)