feat: implement user identity middleware for session management

- Added UserIdentityMiddleware to manage user sessions by generating a unique user ID stored in a cookie.
- Implemented logic to update the last seen timestamp for existing users in the database.
- Enhanced ScaleResult model to associate responses with users, improving data tracking and user experience.
This commit is contained in:
Miu Li 2025-06-17 15:44:10 +08:00
parent 27f2479108
commit cfb7ddedd3
2 changed files with 68 additions and 3 deletions

44
app.py
View File

@ -9,7 +9,7 @@ import uvicorn
from datetime import datetime from datetime import datetime
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from database import get_db, ScaleResult from database import get_db, ScaleResult, new_user, User
import geoip2.database import geoip2.database
from datetime import datetime, UTC from datetime import datetime, UTC
import csv import csv
@ -100,8 +100,49 @@ class LanguageMiddleware(BaseHTTPMiddleware):
return response return response
class UserIdentityMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp):
super().__init__(app)
async def dispatch(self, request: Request, call_next):
# Get user_id from cookie
user_id = request.cookies.get("user_id")
# If no user_id in cookie, generate a new one
if not user_id:
user_id = new_user()
else:
# Update last_seen for existing user
db = next(get_db())
try:
user = db.query(User).filter(User.id == int(user_id)).first()
if user:
user.last_seen = datetime.now(UTC)
db.commit()
finally:
db.close()
# Add user_id to request state
request.state.user_id = user_id
# Continue processing the request
response = await call_next(request)
# Set cookie if it's not already set
if not request.cookies.get("user_id"):
response.set_cookie(
key="user_id",
value=user_id,
max_age=None, # Cookie will never expire
httponly=True,
samesite="lax"
)
return response
app = FastAPI() app = FastAPI()
app.add_middleware(LanguageMiddleware) app.add_middleware(LanguageMiddleware)
app.add_middleware(UserIdentityMiddleware)
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")
templates = {} templates = {}
for lang in os.listdir("templates"): for lang in os.listdir("templates"):
@ -200,6 +241,7 @@ async def result(request: Request, scale_id: str, db: Session = Depends(get_db))
location = get_location_from_ip(ip)# Get location information location = get_location_from_ip(ip)# Get location information
db_response = ScaleResult( db_response = ScaleResult(
scale_id=scale_id, scale_id=scale_id,
user_id=request.state.user_id,
user_agent=request.headers.get("user-agent", "Unknown"), user_agent=request.headers.get("user-agent", "Unknown"),
ip_address=ip, ip_address=ip,
location=location, location=location,

View File

@ -1,7 +1,8 @@
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, ForeignKey, JSON from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, ForeignKey, JSON, func
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.orm import sessionmaker, relationship
import json import json
from datetime import datetime, UTC
SQLALCHEMY_DATABASE_URL = "sqlite:///./psychoscales.db" SQLALCHEMY_DATABASE_URL = "sqlite:///./psychoscales.db"
@ -14,9 +15,17 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
created_at = Column(DateTime)
last_seen = Column(DateTime)
responses = relationship("ScaleResult", back_populates="user")
class ScaleResult(Base): class ScaleResult(Base):
__tablename__ = "responses" __tablename__ = "responses"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"))
scale_id = Column(String, index=True) scale_id = Column(String, index=True)
user_agent = Column(String) user_agent = Column(String)
ip_address = Column(String) ip_address = Column(String)
@ -25,6 +34,7 @@ class ScaleResult(Base):
sum_response = Column(JSON) sum_response = Column(JSON)
avg_response = Column(JSON) avg_response = Column(JSON)
created_at = Column(DateTime) created_at = Column(DateTime)
user = relationship("User", back_populates="responses")
# Create tables # Create tables
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@ -36,3 +46,16 @@ def get_db():
yield db yield db
finally: finally:
db.close() db.close()
def new_user() -> int:
db = SessionLocal()
try:
with db.begin():
user = User()
user.last_seen = user.created_at = datetime.now(UTC)
db.add(user)
db.flush()
return user.id
finally:
db.close()