Skip to content

Commit 440488e

Browse files
authored
Merge pull request #15 from zazayaya/master
add owner loader
2 parents be3596f + 1dfa66c commit 440488e

2 files changed

Lines changed: 64 additions & 0 deletions

File tree

flask_authz/casbin_enforcer.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(self, app, adapter, watcher=None):
2727
self.e = casbin.Enforcer(app.config.get("CASBIN_MODEL"), self.adapter, True)
2828
if watcher:
2929
self.e.set_watcher(watcher)
30+
self._owner_loader = None
3031

3132
def set_watcher(self, watcher):
3233
"""
@@ -38,6 +39,17 @@ def set_watcher(self, watcher):
3839
"""
3940
self.e.set_watcher(watcher)
4041

42+
def owner_loader(self, callback):
43+
"""
44+
This sets the callback for get owner. The
45+
function return a owner object, or ``None``
46+
47+
:param callback: The callback for retrieving a owner object.
48+
:type callback: callable
49+
"""
50+
self._owner_loader = callback
51+
return callback
52+
4153
def enforcer(self, func):
4254
@wraps(func)
4355
def wrapper(*args, **kwargs):
@@ -50,6 +62,14 @@ def wrapper(*args, **kwargs):
5062
)
5163
# Set resource URI from request
5264
uri = str(request.path)
65+
# Get owner from owner_loader
66+
if self._owner_loader:
67+
self.app.logger.info("Get owner from owner_loader")
68+
for owner in self._owner_loader():
69+
if self.e.enforce(
70+
owner.strip('"'), uri, request.method
71+
):
72+
return func(*args, **kwargs)
5373
for header in self.app.config.get("CASBIN_OWNER_HEADERS"):
5474
if header in request.headers:
5575
# Make Authorization Header Parser standard

tests/test_casbin_enforcer.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,47 @@ def test_enforcer_set_watcher(enforcer, watcher):
147147
assert enforcer.e.watcher is None
148148
enforcer.set_watcher(watcher())
149149
assert isinstance(enforcer.e.watcher, watcher)
150+
151+
152+
@pytest.mark.parametrize(
153+
"owner, method, status",
154+
[
155+
(["alice"], "GET", 200),
156+
(["alice"], "POST", 201),
157+
(["alice"], "DELETE", 202),
158+
(["bob"], "GET", 200),
159+
(["bob"], "POST", 401),
160+
(["bob"], "DELETE", 401),
161+
(["admin"], "GET", 401),
162+
(["users"], "GET", 200),
163+
(["alice", "bob"], "POST", 201),
164+
(["noexist", "testnoexist"], "POST", 401),
165+
],
166+
)
167+
def test_enforcer_with_owner_loader(app_fixture, enforcer, owner, method, status):
168+
@app_fixture.route("/")
169+
@enforcer.enforcer
170+
def index():
171+
return jsonify({"message": "passed"}), 200
172+
173+
@app_fixture.route("/item", methods=["GET", "POST", "DELETE"])
174+
@enforcer.enforcer
175+
def item():
176+
if request.method == "GET":
177+
return jsonify({"message": "passed"}), 200
178+
elif request.method == "POST":
179+
return jsonify({"message": "passed"}), 201
180+
elif request.method == "DELETE":
181+
return jsonify({"message": "passed"}), 202
182+
183+
@enforcer.owner_loader
184+
def owner_loader():
185+
return owner
186+
187+
c = app_fixture.test_client()
188+
# c.post('/add', data=dict(title='2nd Item', text='The text'))
189+
rv = c.get("/")
190+
assert rv.status_code == 401
191+
caller = getattr(c, method.lower())
192+
rv = caller("/item")
193+
assert rv.status_code == status

0 commit comments

Comments
 (0)