|
| 1 | +""" |
| 2 | +Dependencies for FastAPI application |
| 3 | +
|
| 4 | +How it works: |
| 5 | + *Declare dependencies* |
| 6 | + - Define regular functions, these functions can take parameters, including other dependencies |
| 7 | +
|
| 8 | + *Inject dependencies* |
| 9 | + - Inject dependencies into routes using the `Depends` function (decorated with `@` like @app.get, @app.post) |
| 10 | +
|
| 11 | +Benefits: |
| 12 | + - Improved Code Organization: Separates concerns by allowing you to extract shared logic into reusable dependency functions. |
| 13 | + - Enhanced Testability: Dependencies can be easily mocked or replaced during testing. |
| 14 | + - Reusability: Dependencies can be used across multiple routes and endpoints. |
| 15 | + - Reduced duplication of code: Dependencies can be used across multiple routes and endpoints. |
| 16 | + - Better maintainability: Centralizes dependency creation and management, simplifying updates and changes. |
| 17 | + - Automatic Handling: Dependencies can be automatically handled by FastAPI, such as dependency injection and error handling. |
| 18 | +
|
| 19 | +""" |
| 20 | +from fastapi import Depends, HTTPException, status, Header, Query |
| 21 | +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| 22 | +from typing import Optional, List |
| 23 | +import jwt |
| 24 | +from datetime import datetime, timedelta |
| 25 | +import hashlib |
| 26 | + |
| 27 | +# Security |
| 28 | +security = HTTPBearer() |
| 29 | + |
| 30 | +# Mock user database |
| 31 | +users_db = { |
| 32 | + "admin": { |
| 33 | + "id": 1, |
| 34 | + "username": "admin", |
| 35 | + "email": "admin@example.com", |
| 36 | + "hashed_password": hashlib.sha256("admin123".encode()).hexdigest(), |
| 37 | + "roles": ["admin", "user"], |
| 38 | + "is_active": True |
| 39 | + }, |
| 40 | + "user": { |
| 41 | + "id": 2, |
| 42 | + "username": "jay", |
| 43 | + "email": "jay@example.com", |
| 44 | + "hashed_password": hashlib.sha256("user123".encode()).hexdigest(), |
| 45 | + "roles": ["user"], |
| 46 | + "is_active": True |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +SECRET_KEY = "python-secret-key" # In real application, this should be a secret key stored securely |
| 51 | +ALGORITHM = "HS256" |
| 52 | + |
| 53 | +class User: |
| 54 | + def __init__(self, id: int, username: str, email: str, roles: List[str], is_active: bool): |
| 55 | + self.id = id |
| 56 | + self.username = username |
| 57 | + self.email = email |
| 58 | + self.roles = roles |
| 59 | + self.is_active = is_active |
| 60 | + |
| 61 | +# Database dependency |
| 62 | +class DatabaseConnection: |
| 63 | + def __init__(self): |
| 64 | + self.connected = True # Mocked for demonstration purposes, this should be replaced with actual database connection logic |
| 65 | + self.connection_id = f"conn_{datetime.now().timestamp()}" |
| 66 | + |
| 67 | + def close(self): |
| 68 | + self.connected = False |
| 69 | + |
| 70 | +def get_database(): |
| 71 | + db = DatabaseConnection() |
| 72 | + try: |
| 73 | + yield db |
| 74 | + finally: |
| 75 | + db.close() |
| 76 | + |
| 77 | +# Authentication dependencies |
| 78 | +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): |
| 79 | + to_encode = data.copy() |
| 80 | + if expires_delta: |
| 81 | + expire = datetime.now() + expires_delta |
| 82 | + else: |
| 83 | + expire = datetime.now() + timedelta(minutes=15) |
| 84 | + to_encode.update({"exp": expire}) |
| 85 | + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) |
| 86 | + return encoded_jwt |
| 87 | + |
| 88 | +def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)): |
| 89 | + try: |
| 90 | + payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]) |
| 91 | + username: str = payload.get("sub") |
| 92 | + print(f"Token decoded: username={username}, payload={payload}") # Debug output |
| 93 | + if username is None: |
| 94 | + raise HTTPException( |
| 95 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 96 | + detail="Could not validate credentials", |
| 97 | + headers={"WWW-Authenticate": "Bearer"}, |
| 98 | + ) |
| 99 | + return username |
| 100 | + except jwt.PyJWTError as e: |
| 101 | + print(f"Token validation failed: {str(e)}") # Debug output |
| 102 | + raise HTTPException( |
| 103 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 104 | + detail="Could not validate credentials", |
| 105 | + headers={"WWW-Authenticate": "Bearer"}, |
| 106 | + ) |
| 107 | + |
| 108 | +""" |
| 109 | +In this get_current_user, verify_token is the dependency that verifies the token and returns the username. |
| 110 | +""" |
| 111 | +def get_current_user(username: str = Depends(verify_token)) -> User: |
| 112 | + user_data = users_db.get(username) |
| 113 | + if user_data is None: |
| 114 | + raise HTTPException( |
| 115 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 116 | + detail="User not found" |
| 117 | + ) |
| 118 | + if not user_data["is_active"]: |
| 119 | + raise HTTPException( |
| 120 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 121 | + detail="Inactive user" |
| 122 | + ) |
| 123 | + user_data_without_password = { |
| 124 | + "id": user_data["id"], |
| 125 | + "username": user_data["username"], |
| 126 | + "email": user_data["email"], |
| 127 | + "roles": user_data["roles"], |
| 128 | + "is_active": user_data["is_active"] |
| 129 | + } |
| 130 | + return User(**user_data_without_password) |
| 131 | + |
| 132 | +def get_admin_user(current_user: User = Depends(get_current_user)) -> User: |
| 133 | + if "admin" not in current_user.roles: |
| 134 | + raise HTTPException( |
| 135 | + status_code=status.HTTP_403_FORBIDDEN, |
| 136 | + detail="Not enough permissions" |
| 137 | + ) |
| 138 | + return current_user |
| 139 | + |
| 140 | +# Pagination dependency |
| 141 | +class PaginationParams: |
| 142 | + def __init__( |
| 143 | + self, |
| 144 | + skip: int = Query(0, ge=0, description="Number of records to skip"), |
| 145 | + limit: int = Query(10, ge=1, le=100, description="Number of records to return") |
| 146 | + ): |
| 147 | + self.skip = skip |
| 148 | + self.limit = limit |
| 149 | + |
| 150 | +# Sorting dependency |
| 151 | +class SortingParams: |
| 152 | + def __init__( |
| 153 | + self, |
| 154 | + sort_by: str = Query("id", description="Field to sort by"), |
| 155 | + sort_order: str = Query("asc", patterns="^(asc|desc)$", description="Sort order") |
| 156 | + ): |
| 157 | + self.sort_by = sort_by |
| 158 | + self.sort_order = sort_order |
| 159 | + |
| 160 | +# Rate limiting dependency |
| 161 | +class RateLimiter: |
| 162 | + def __init__(self): |
| 163 | + self.requests = {} |
| 164 | + |
| 165 | + def is_allowed(self, client_ip: str, limit: int = 100, window: int = 3600) -> bool: |
| 166 | + now = datetime.now() |
| 167 | + if client_ip not in self.requests: |
| 168 | + self.requests[client_ip] = [] |
| 169 | + |
| 170 | + # Clean old requests |
| 171 | + self.requests[client_ip] = [ |
| 172 | + req_time for req_time in self.requests[client_ip] |
| 173 | + if (now - req_time).seconds < window |
| 174 | + ] |
| 175 | + |
| 176 | + if len(self.requests[client_ip]) >= limit: |
| 177 | + return False |
| 178 | + |
| 179 | + self.requests[client_ip].append(now) |
| 180 | + return True |
| 181 | + |
| 182 | +rate_limiter = RateLimiter() |
| 183 | + |
| 184 | +def check_rate_limit( |
| 185 | + x_forwarded_for: Optional[str] = Header(None), |
| 186 | + x_real_ip: Optional[str] = Header(None) |
| 187 | +): |
| 188 | + client_ip = x_forwarded_for or x_real_ip or "127.0.0.1" |
| 189 | + if not rate_limiter.is_allowed(client_ip): |
| 190 | + raise HTTPException( |
| 191 | + status_code=status.HTTP_429_TOO_MANY_REQUESTS, |
| 192 | + detail="Rate limit exceeded" |
| 193 | + ) |
| 194 | + return client_ip |
| 195 | + |
| 196 | +# Validation dependencies |
| 197 | +def validate_positive_int(value: int) -> int: |
| 198 | + if value <= 0: |
| 199 | + raise HTTPException( |
| 200 | + status_code=status.HTTP_400_BAD_REQUEST, |
| 201 | + detail="Value must be positive" |
| 202 | + ) |
| 203 | + return value |
| 204 | + |
| 205 | +def get_item_id(item_id: int) -> int: |
| 206 | + return validate_positive_int(item_id) |
| 207 | + |
| 208 | +# Common query parameters |
| 209 | +class CommonQueryParams: |
| 210 | + def __init__( |
| 211 | + self, |
| 212 | + q: Optional[str] = Query(None, min_length=1, max_length=50, description="Search query"), |
| 213 | + include_inactive: bool = Query(False, description="Include inactive items") |
| 214 | + ): |
| 215 | + self.q = q |
| 216 | + self.include_inactive = include_inactive |
0 commit comments