Skip to content

Commit

Permalink
Merge pull request #175 from jazzband/simplify
Browse files Browse the repository at this point in the history
Simplify middleware and db backend
  • Loading branch information
blag authored Oct 15, 2023
2 parents 5ab2c57 + ef28ea0 commit 8064ec8
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 146 deletions.
102 changes: 28 additions & 74 deletions user_sessions/backends/db.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import logging

from django.contrib import auth
from django.contrib.sessions.backends.base import CreateError, SessionBase
from django.core.exceptions import SuspiciousOperation
from django.db import IntegrityError, router, transaction
from django.utils import timezone
from django.utils.encoding import force_str
from django.contrib.sessions.backends.db import SessionStore as DBStore


class SessionStore(SessionBase):
class SessionStore(DBStore):
"""
Implements database session store.
"""
Expand All @@ -19,89 +13,49 @@ def __init__(self, session_key=None, user_agent=None, ip=None):
self.ip = ip
self.user_id = None

# Used by superclass to get self.model, which is used elsewhere
@classmethod
def get_model_class(cls):
# Avoids a circular import and allows importing SessionStore when
# user_sessions is not in INSTALLED_APPS
from ..models import Session

return Session

def __setitem__(self, key, value):
if key == auth.SESSION_KEY:
self.user_id = value
super().__setitem__(key, value)

def load(self):
try:
s = Session.objects.get(
session_key=self.session_key,
expire_date__gt=timezone.now()
)
self.user_id = s.user_id
# do not overwrite user_agent/ip, as those might have been updated
if self.user_agent != s.user_agent or self.ip != s.ip:
self.modified = True
return self.decode(s.session_data)
except (Session.DoesNotExist, SuspiciousOperation) as e:
if isinstance(e, SuspiciousOperation):
logger = logging.getLogger('django.security.%s' %
e.__class__.__name__)
logger.warning(force_str(e))
self.create()
return {}

def exists(self, session_key):
return Session.objects.filter(session_key=session_key).exists()
# Used in DBStore.load()
def _get_session_from_db(self):
s = super()._get_session_from_db()
self.user_id = s.user_id
# do not overwrite user_agent/ip, as those might have been updated
if self.user_agent != s.user_agent or self.ip != s.ip:
self.modified = True
return s

def create(self):
while True:
self._session_key = self._get_new_session_key()
try:
# Save immediately to ensure we have a unique entry in the
# database.
self.save(must_create=True)
except CreateError:
# Key wasn't unique. Try again.
continue
self.modified = True
self._session_cache = {}
return
super().create()
self._session_cache = {}

def save(self, must_create=False):
# Used in DBStore.save()
def create_model_instance(self, data):
"""
Saves the current session data to the database. If 'must_create' is
True, a database error will be raised if the saving operation doesn't
create a *new* entry (as opposed to possibly updating an existing
entry).
Return a new instance of the session model object, which represents the
current session state. Intended to be used for saving the session data
to the database.
"""
obj = Session(
return self.model(
session_key=self._get_or_create_session_key(),
session_data=self.encode(self._get_session(no_load=must_create)),
session_data=self.encode(data),
expire_date=self.get_expiry_date(),
user_agent=self.user_agent,
user_id=self.user_id,
ip=self.ip,
)
using = router.db_for_write(Session, instance=obj)
try:
with transaction.atomic(using):
obj.save(force_insert=must_create, using=using)
except IntegrityError as e:
if must_create and 'session_key' in str(e):
raise CreateError
raise

def clear(self):
super().clear()
self.user_id = None

def delete(self, session_key=None):
if session_key is None:
if self.session_key is None:
return
session_key = self.session_key
try:
Session.objects.get(session_key=session_key).delete()
except Session.DoesNotExist:
pass

@classmethod
def clear_expired(cls):
Session.objects.filter(expire_date__lt=timezone.now()).delete()


# At bottom to avoid circular import
from ..models import Session # noqa: E402 isort:skip
61 changes: 5 additions & 56 deletions user_sessions/middleware.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,17 @@
import time

from django.conf import settings
from django.utils.cache import patch_vary_headers
from django.utils.http import http_date

try:
from importlib import import_module
except ImportError:
from django.utils.importlib import import_module
from django.contrib.sessions.middleware import (
SessionMiddleware as DjangoSessionMiddleware,
)

try:
from django.utils.deprecation import MiddlewareMixin
except ImportError:
class MiddlewareMixin:
pass


class SessionMiddleware(MiddlewareMixin):
class SessionMiddleware(DjangoSessionMiddleware):
"""
Middleware that provides ip and user_agent to the session store.
"""
def process_request(self, request):
engine = import_module(settings.SESSION_ENGINE)
session_key = request.COOKIES.get(settings.SESSION_COOKIE_NAME, None)
request.session = engine.SessionStore(
request.session = self.SessionStore(
ip=request.META.get('REMOTE_ADDR', ''),
user_agent=request.META.get('HTTP_USER_AGENT', ''),
session_key=session_key
)

def process_response(self, request, response):
"""
If request.session was modified, or if the configuration is to save the
session every time, save the changes and set a session cookie.
"""
try:
accessed = request.session.accessed
modified = request.session.modified
except AttributeError:
pass
else:
if accessed:
patch_vary_headers(response, ('Cookie',))
if modified or settings.SESSION_SAVE_EVERY_REQUEST:
if request.session.get_expire_at_browser_close():
max_age = None
expires = None
else:
max_age = request.session.get_expiry_age()
expires_time = time.time() + max_age
expires = http_date(expires_time)
# Save the session data and refresh the client cookie.
# Skip session save for 500 responses, refs #3881.
if response.status_code != 500:
request.session.save()
response.set_cookie(
settings.SESSION_COOKIE_NAME,
request.session.session_key,
max_age=max_age,
expires=expires,
domain=settings.SESSION_COOKIE_DOMAIN,
path=settings.SESSION_COOKIE_PATH,
secure=settings.SESSION_COOKIE_SECURE or None,
httponly=settings.SESSION_COOKIE_HTTPONLY or None,
samesite=settings.SESSION_COOKIE_SAMESITE,
)
return response
18 changes: 8 additions & 10 deletions user_sessions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from django.db import models
from django.utils.translation import gettext_lazy as _

from .backends.db import SessionStore


class SessionManager(models.Manager):
use_in_migrations = True
Expand Down Expand Up @@ -38,6 +40,12 @@ class Session(models.Model):
primary_key=True)
session_data = models.TextField(_('session data'))
expire_date = models.DateTimeField(_('expiry date'), db_index=True)
user = models.ForeignKey(getattr(settings, 'AUTH_USER_MODEL', 'auth.User'),
null=True, on_delete=models.CASCADE)
user_agent = models.CharField(null=True, blank=True, max_length=200)
last_activity = models.DateTimeField(auto_now=True)
ip = models.GenericIPAddressField(null=True, blank=True, verbose_name='IP')

objects = SessionManager()

class Meta:
Expand All @@ -46,13 +54,3 @@ class Meta:

def get_decoded(self):
return SessionStore(None, None).decode(self.session_data)

user = models.ForeignKey(getattr(settings, 'AUTH_USER_MODEL', 'auth.User'),
null=True, on_delete=models.CASCADE)
user_agent = models.CharField(null=True, blank=True, max_length=200)
last_activity = models.DateTimeField(auto_now=True)
ip = models.GenericIPAddressField(null=True, blank=True, verbose_name='IP')


# At bottom to avoid circular import
from .backends.db import SessionStore # noqa: E402 isort:skip
10 changes: 4 additions & 6 deletions user_sessions/urls.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from django.urls import path, re_path
from django.urls import path

from user_sessions.views import SessionDeleteOtherView

from .views import SessionDeleteView, SessionListView
from .views import SessionDeleteOtherView, SessionDeleteView, SessionListView

app_name = 'user_sessions'
urlpatterns = [
Expand All @@ -16,8 +14,8 @@
view=SessionDeleteOtherView.as_view(),
name='session_delete_other',
),
re_path(
r'^account/sessions/(?P<pk>\w+)/delete/$',
path(
'account/sessions/<str:pk>/delete/',
view=SessionDeleteView.as_view(),
name='session_delete',
),
Expand Down

0 comments on commit 8064ec8

Please sign in to comment.