Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify #175

Merged
merged 4 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 25 additions & 74 deletions user_sessions/backends/db.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import logging

from django.contrib import auth
from django.contrib.sessions.backends.base import CreateError, SessionBase
from django.core.exceptions import SuspiciousOperation
from django.db import IntegrityError, router, transaction
from django.utils import timezone
from django.utils.encoding import force_str
from django.contrib.sessions.backends.db import SessionStore as DBStore


class SessionStore(SessionBase):
class SessionStore(DBStore):
"""
Implements database session store.
"""
Expand All @@ -19,89 +13,46 @@ def __init__(self, session_key=None, user_agent=None, ip=None):
self.ip = ip
self.user_id = None

@classmethod
def get_model_class(cls):
# Avoids a circular import and allows importing SessionStore when
# user_sessions is not in INSTALLED_APPS
from ..models import Session

return Session

def __setitem__(self, key, value):
if key == auth.SESSION_KEY:
self.user_id = value
super().__setitem__(key, value)

def load(self):
try:
s = Session.objects.get(
session_key=self.session_key,
expire_date__gt=timezone.now()
)
self.user_id = s.user_id
# do not overwrite user_agent/ip, as those might have been updated
if self.user_agent != s.user_agent or self.ip != s.ip:
self.modified = True
return self.decode(s.session_data)
except (Session.DoesNotExist, SuspiciousOperation) as e:
if isinstance(e, SuspiciousOperation):
logger = logging.getLogger('django.security.%s' %
e.__class__.__name__)
logger.warning(force_str(e))
self.create()
return {}

def exists(self, session_key):
return Session.objects.filter(session_key=session_key).exists()
def _get_session_from_db(self):
s = super()._get_session_from_db()
self.user_id = s.user_id
# do not overwrite user_agent/ip, as those might have been updated
WhyNotHugo marked this conversation as resolved.
Show resolved Hide resolved
if self.user_agent != s.user_agent or self.ip != s.ip:
self.modified = True
return s

def create(self):
while True:
self._session_key = self._get_new_session_key()
try:
# Save immediately to ensure we have a unique entry in the
# database.
self.save(must_create=True)
except CreateError:
# Key wasn't unique. Try again.
continue
self.modified = True
self._session_cache = {}
return
super().create()
self._session_cache = {}

def save(self, must_create=False):
def create_model_instance(self, data):
"""
Saves the current session data to the database. If 'must_create' is
True, a database error will be raised if the saving operation doesn't
create a *new* entry (as opposed to possibly updating an existing
entry).
Return a new instance of the session model object, which represents the
current session state. Intended to be used for saving the session data
to the database.
"""
obj = Session(
return self.model(
session_key=self._get_or_create_session_key(),
session_data=self.encode(self._get_session(no_load=must_create)),
session_data=self.encode(data),
expire_date=self.get_expiry_date(),
user_agent=self.user_agent,
user_id=self.user_id,
ip=self.ip,
)
using = router.db_for_write(Session, instance=obj)
try:
with transaction.atomic(using):
obj.save(force_insert=must_create, using=using)
except IntegrityError as e:
if must_create and 'session_key' in str(e):
raise CreateError
raise

def clear(self):
super().clear()
self.user_id = None

def delete(self, session_key=None):
if session_key is None:
if self.session_key is None:
return
session_key = self.session_key
try:
Session.objects.get(session_key=session_key).delete()
except Session.DoesNotExist:
pass

@classmethod
def clear_expired(cls):
Session.objects.filter(expire_date__lt=timezone.now()).delete()


# At bottom to avoid circular import
from ..models import Session # noqa: E402 isort:skip
61 changes: 5 additions & 56 deletions user_sessions/middleware.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,17 @@
import time

from django.conf import settings
from django.utils.cache import patch_vary_headers
from django.utils.http import http_date

try:
from importlib import import_module
except ImportError:
from django.utils.importlib import import_module
from django.contrib.sessions.middleware import (
SessionMiddleware as DjangoSessionMiddleware,
)

try:
from django.utils.deprecation import MiddlewareMixin
except ImportError:
class MiddlewareMixin:
pass


class SessionMiddleware(MiddlewareMixin):
class SessionMiddleware(DjangoSessionMiddleware):
"""
Middleware that provides ip and user_agent to the session store.
"""
def process_request(self, request):
engine = import_module(settings.SESSION_ENGINE)
session_key = request.COOKIES.get(settings.SESSION_COOKIE_NAME, None)
request.session = engine.SessionStore(
request.session = self.SessionStore(
ip=request.META.get('REMOTE_ADDR', ''),
WhyNotHugo marked this conversation as resolved.
Show resolved Hide resolved
user_agent=request.META.get('HTTP_USER_AGENT', ''),
session_key=session_key
)

def process_response(self, request, response):
"""
If request.session was modified, or if the configuration is to save the
session every time, save the changes and set a session cookie.
"""
try:
accessed = request.session.accessed
modified = request.session.modified
except AttributeError:
pass
else:
if accessed:
patch_vary_headers(response, ('Cookie',))
if modified or settings.SESSION_SAVE_EVERY_REQUEST:
if request.session.get_expire_at_browser_close():
max_age = None
expires = None
else:
max_age = request.session.get_expiry_age()
expires_time = time.time() + max_age
expires = http_date(expires_time)
# Save the session data and refresh the client cookie.
# Skip session save for 500 responses, refs #3881.
if response.status_code != 500:
request.session.save()
response.set_cookie(
settings.SESSION_COOKIE_NAME,
request.session.session_key,
max_age=max_age,
expires=expires,
domain=settings.SESSION_COOKIE_DOMAIN,
path=settings.SESSION_COOKIE_PATH,
secure=settings.SESSION_COOKIE_SECURE or None,
httponly=settings.SESSION_COOKIE_HTTPONLY or None,
samesite=settings.SESSION_COOKIE_SAMESITE,
)
return response
18 changes: 8 additions & 10 deletions user_sessions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from django.db import models
from django.utils.translation import gettext_lazy as _

from .backends.db import SessionStore


class SessionManager(models.Manager):
use_in_migrations = True
Expand Down Expand Up @@ -38,6 +40,12 @@ class Session(models.Model):
primary_key=True)
session_data = models.TextField(_('session data'))
expire_date = models.DateTimeField(_('expiry date'), db_index=True)
user = models.ForeignKey(getattr(settings, 'AUTH_USER_MODEL', 'auth.User'),
null=True, on_delete=models.CASCADE)
user_agent = models.CharField(null=True, blank=True, max_length=200)
last_activity = models.DateTimeField(auto_now=True)
ip = models.GenericIPAddressField(null=True, blank=True, verbose_name='IP')

objects = SessionManager()

class Meta:
Expand All @@ -46,13 +54,3 @@ class Meta:

def get_decoded(self):
return SessionStore(None, None).decode(self.session_data)

user = models.ForeignKey(getattr(settings, 'AUTH_USER_MODEL', 'auth.User'),
null=True, on_delete=models.CASCADE)
user_agent = models.CharField(null=True, blank=True, max_length=200)
last_activity = models.DateTimeField(auto_now=True)
ip = models.GenericIPAddressField(null=True, blank=True, verbose_name='IP')


# At bottom to avoid circular import
from .backends.db import SessionStore # noqa: E402 isort:skip
10 changes: 4 additions & 6 deletions user_sessions/urls.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from django.urls import path, re_path
from django.urls import path

from user_sessions.views import SessionDeleteOtherView

from .views import SessionDeleteView, SessionListView
from .views import SessionDeleteOtherView, SessionDeleteView, SessionListView

app_name = 'user_sessions'
urlpatterns = [
Expand All @@ -16,8 +14,8 @@
view=SessionDeleteOtherView.as_view(),
name='session_delete_other',
),
re_path(
r'^account/sessions/(?P<pk>\w+)/delete/$',
path(
'account/sessions/<str:pk>/delete/',
view=SessionDeleteView.as_view(),
name='session_delete',
),
Expand Down
Loading