aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor Pashev <pashev.igor@gmail.com>2017-07-26 21:09:57 +0300
committerIgor Pashev <pashev.igor@gmail.com>2017-07-26 21:09:57 +0300
commitbb31be8f6072e4dd72c8630c019f7ab5e0bc9fa9 (patch)
treed7b4194f8b6b5e7af76cf59b4130a08153ae44da
parent1123c543bdd438ad40428e7814325a53c819cee2 (diff)
downloadsproxy2-state.tar.gz
[WIP] State in OAuth2 callback should be short-livedstate
-rw-r--r--sproxy2.cabal1
-rw-r--r--src/Sproxy/Application.hs77
-rw-r--r--src/Sproxy/Application/Cookie.hs33
-rw-r--r--src/Sproxy/Application/State.hs30
-rw-r--r--src/Sproxy/Config.hs3
5 files changed, 79 insertions, 65 deletions
diff --git a/sproxy2.cabal b/sproxy2.cabal
index e0f4375..3cd6281 100644
--- a/sproxy2.cabal
+++ b/sproxy2.cabal
@@ -68,6 +68,7 @@ executable sproxy2
, sqlite-simple
, text
, time
+ , transformers
, unix
, unordered-containers
, wai
diff --git a/src/Sproxy/Application.hs b/src/Sproxy/Application.hs
index 3d6598f..2e273ab 100644
--- a/src/Sproxy/Application.hs
+++ b/src/Sproxy/Application.hs
@@ -9,6 +9,9 @@ module Sproxy.Application (
import Blaze.ByteString.Builder (toByteString)
import Blaze.ByteString.Builder.ByteString (fromByteString)
import Control.Exception (Exception, Handler(..), SomeException, catches, displayException)
+import Control.Monad (mzero)
+import Control.Monad.IO.Class (liftIO)
+import Control.Monad.Trans.Maybe (MaybeT(..), runMaybeT)
import Data.ByteString (ByteString)
import Data.ByteString as BS (break, intercalate)
import Data.ByteString.Char8 (pack, unpack)
@@ -36,7 +39,6 @@ import Network.HTTP.Types.Status ( Status(..), badGateway502, badRequest400, for
import Network.Socket (NameInfoFlag(NI_NUMERICHOST), getNameInfo)
import Network.Wai.Conduit (sourceRequestBody, responseSource)
import System.FilePath.Glob (Pattern, match)
-import System.Posix.Time (epochTime)
import Text.InterpolatedString.Perl6 (qc)
import Web.Cookie (Cookies, parseCookies, renderCookies)
import qualified Data.Aeson as JSON
@@ -44,7 +46,7 @@ import qualified Network.HTTP.Client as BE
import qualified Network.Wai as W
import qualified Web.Cookie as WC
-import Sproxy.Application.Cookie ( AuthCookie(..), AuthUser,
+import Sproxy.Application.Cookie (AuthUser,
cookieDecode, cookieEncode, getEmail, getEmailUtf8, getFamilyNameUtf8,
getGivenNameUtf8 )
import Sproxy.Application.OAuth2.Common (OAuth2Client(..))
@@ -90,17 +92,15 @@ sproxy key db oa2 backends = logException $ \req resp -> do
Nothing -> notFound "OAuth2 provider" req resp
Just oa2c -> get (oauth2callback key db (provider, oa2c) be) req resp
- ["access"] -> do
- now <- Just <$> epochTime
- case extractCookie key now cookieName req of
+ ["access"] ->
+ extractCookie key cookieName req >>= \case
Nothing -> authenticationRequired key oa2 req resp
Just (authCookie, _) -> post (checkAccess db authCookie) req resp
_ -> notFound "proxy" req resp
- _ -> do
- now <- Just <$> epochTime
- case extractCookie key now cookieName req of
+ _ ->
+ extractCookie key cookieName req >>= \case
Nothing -> authenticationRequired key oa2 req resp
Just cs@(authCookie, _) ->
authorize db cs req >>= \case
@@ -122,7 +122,7 @@ oauth2callback key db (provider, oa2c) be req resp =
case param "state" of
Nothing -> badRequest "missing auth state" req resp
Just state ->
- case State.decode key state of
+ State.decode key state >>= \case
Left msg -> badRequest ("invalid state: " ++ msg) req resp
Right path -> do
au <- oauth2Authenticate oa2c code (redirectURL req provider)
@@ -138,8 +138,24 @@ oauth2callback key db (provider, oa2c) be req resp =
-- XXX: RFC6265: the user agent MUST NOT attach more than one Cookie header field
-extractCookie :: ByteString -> Maybe CTime -> ByteString -> W.Request -> Maybe (AuthCookie, Cookies)
-extractCookie key now name req = do
+extractCookie :: ByteString -> ByteString -> W.Request -> IO (Maybe (AuthUser, Cookies))
+extractCookie key name req = runMaybeT $ do
+ (_, cookies) <- findCookieHeader
+ (auth, others) <- discriminate cookies
+ liftIO $ cookieDecode key auth >>= \case
+ Left err -> do
+ Log.debug ("extract cookie: " ++ show err)
+ mzero
+ Right user -> return (user, others)
+
+ where
+ findCookieHeader =
+ MaybeT . return $ find ((==) hCookie . fst) (W.requestHeaders req)
+ discriminate cs =
+ case partition ((==) name . fst) $ parseCookies cs of
+ ((_, x):_, xs) -> return (x, xs)
+ _ -> mzero
+{-
(_, cookies) <- find ((==) hCookie . fst) $ W.requestHeaders req
(auth, others) <- discriminate cookies
case cookieDecode key auth of
@@ -150,21 +166,19 @@ extractCookie key now name req = do
case partition ((==) name . fst) $ parseCookies cs of
((_, x):_, xs) -> Just (x, xs)
_ -> Nothing
-
+-}
authenticate :: ByteString -> BackendConf -> AuthUser -> ByteString -> W.Application
authenticate key be user path req resp = do
- now <- epochTime
+ (authCookie, expiry) <- cookieEncode key (beCookieMaxAge be) user
let domain = pack <$> beCookieDomain be
- expiry = now + CTime (beCookieMaxAge be)
- authCookie = AuthCookie { acUser = user, acExpiry = expiry }
cookie = WC.def {
WC.setCookieName = pack $ beCookieName be
, WC.setCookieHttpOnly = True
, WC.setCookiePath = Just "/"
, WC.setCookieSameSite = Nothing
, WC.setCookieSecure = True
- , WC.setCookieValue = cookieEncode key authCookie
+ , WC.setCookieValue = authCookie
, WC.setCookieDomain = domain
, WC.setCookieExpires = Just . posixSecondsToUTCTime . realToFrac $ expiry
}
@@ -174,10 +188,9 @@ authenticate key be user path req resp = do
] ""
-authorize :: Database -> (AuthCookie, Cookies) -> W.Request -> IO (Maybe W.Request)
-authorize db (authCookie, otherCookies) req = do
+authorize :: Database -> (AuthUser, Cookies) -> W.Request -> IO (Maybe W.Request)
+authorize db (user, otherCookies) req = do
let
- user = acUser authCookie
domain = decodeUtf8 . fromJust $ requestDomain req
email = getEmail user
emailUtf8 = getEmailUtf8 user
@@ -206,9 +219,9 @@ authorize db (authCookie, otherCookies) req = do
setCookies cs = insert hCookie (toByteString . renderCookies $ cs)
-checkAccess :: Database -> AuthCookie -> W.Application
-checkAccess db authCookie req resp = do
- let email = getEmail . acUser $ authCookie
+checkAccess :: Database -> AuthUser -> W.Application
+checkAccess db user req resp = do
+ let email = getEmail user
domain = decodeUtf8 . fromJust $ requestDomain req
body <- W.strictRequestBody req
case JSON.eitherDecode' body of
@@ -275,12 +288,11 @@ modifyResponseHeaders = filter (\(n, _) -> n `notElem` ban)
authenticationRequired :: ByteString -> HashMap Text OAuth2Client -> W.Application
authenticationRequired key oa2 req resp = do
Log.info $ "511 Unauthenticated: " ++ showReq req
- resp $ W.responseLBS networkAuthenticationRequired511 [(hContentType, "text/html; charset=utf-8")] page
- where
- path = if W.requestMethod req == methodGet
- then W.rawPathInfo req <> W.rawQueryString req
- else "/"
- state = State.encode key path
+ (state, _) <- State.encode key 60
+ (if W.requestMethod req == methodGet
+ then W.rawPathInfo req <> W.rawQueryString req
+ else "/")
+ let
authLink :: Text -> OAuth2Client -> ByteString -> ByteString
authLink provider oa2c html =
let u = oauth2AuthorizeURL oa2c state (redirectURL req provider)
@@ -300,14 +312,15 @@ authenticationRequired key oa2 req resp = do
</body>
</html>
|]
+ resp $ W.responseLBS networkAuthenticationRequired511 [(hContentType, "text/html; charset=utf-8")] page
-forbidden :: AuthCookie -> W.Application
-forbidden ac req resp = do
+forbidden :: AuthUser -> W.Application
+forbidden user req resp = do
Log.info $ "403 Forbidden: " ++ show email ++ ": " ++ showReq req
resp $ W.responseLBS forbidden403 [(hContentType, "text/html; charset=utf-8")] page
where
- email = getEmailUtf8 . acUser $ ac
+ email = getEmailUtf8 user
page = fromStrict [qc|
<!DOCTYPE html>
<html lang="en">
@@ -349,7 +362,7 @@ userNotFound au _ resp = do
logout :: ByteString -> ByteString -> Maybe ByteString -> W.Application
logout key cookieName cookieDomain req resp = do
let host = fromJust $ W.requestHeaderHost req
- case extractCookie key Nothing cookieName req of
+ extractCookie key cookieName req >>= \case
Nothing -> resp $ W.responseLBS found302 [ (hLocation, "https://" <> host) ] ""
Just _ -> do
let cookie = WC.def {
diff --git a/src/Sproxy/Application/Cookie.hs b/src/Sproxy/Application/Cookie.hs
index a86f42a..5bd15ef 100644
--- a/src/Sproxy/Application/Cookie.hs
+++ b/src/Sproxy/Application/Cookie.hs
@@ -1,7 +1,6 @@
{-# LANGUAGE OverloadedStrings #-}
module Sproxy.Application.Cookie (
- AuthCookie(..)
-, AuthUser
+ AuthUser
, cookieDecode
, cookieEncode
, getEmail
@@ -16,7 +15,7 @@ module Sproxy.Application.Cookie (
import Data.ByteString (ByteString)
import Data.Text (Text, toLower, strip)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
-import Foreign.C.Types (CTime(..))
+import Foreign.C.Types (CTime)
import qualified Data.Serialize as DS
import qualified Sproxy.Application.State as State
@@ -27,28 +26,20 @@ data AuthUser = AuthUser {
, auFamilyName :: ByteString
}
-data AuthCookie = AuthCookie {
- acUser :: AuthUser
-, acExpiry :: CTime
-}
-
-instance DS.Serialize AuthCookie where
- put c = DS.put (auEmail u, auGivenName u, auFamilyName u, x)
- where u = acUser c
- x = (\(CTime i) -> i) $ acExpiry c
+instance DS.Serialize AuthUser where
+ put u = DS.put (auEmail u, auGivenName u, auFamilyName u)
get = do
- (e, n, f, x) <- DS.get
- return AuthCookie {
- acUser = AuthUser { auEmail = e, auGivenName = n, auFamilyName = f }
- , acExpiry = CTime x
- }
+ (e, n, f) <- DS.get
+ return AuthUser { auEmail = e, auGivenName = n, auFamilyName = f }
-cookieDecode :: ByteString -> ByteString -> Either String AuthCookie
-cookieDecode key d = State.decode key d >>= DS.decode
+cookieDecode :: ByteString -> ByteString -> IO (Either String AuthUser)
+cookieDecode key d = do
+ c <- State.decode key d
+ return $ c >>= DS.decode
-cookieEncode :: ByteString -> AuthCookie -> ByteString
-cookieEncode key = State.encode key . DS.encode
+cookieEncode :: ByteString -> Int -> AuthUser -> IO (ByteString, CTime)
+cookieEncode key shelflife = State.encode key (fromIntegral shelflife) . DS.encode
getEmail :: AuthUser -> Text
diff --git a/src/Sproxy/Application/State.hs b/src/Sproxy/Application/State.hs
index 29d9252..8ddbedf 100644
--- a/src/Sproxy/Application/State.hs
+++ b/src/Sproxy/Application/State.hs
@@ -6,6 +6,8 @@ module Sproxy.Application.State (
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (fromStrict, toStrict)
import Data.Digest.Pure.SHA (hmacSha1, bytestringDigest)
+import Foreign.C.Types (CTime(..))
+import System.Posix.Time (epochTime)
import qualified Data.ByteString.Base64 as Base64
import qualified Data.Serialize as DS
@@ -13,16 +15,24 @@ import qualified Data.Serialize as DS
-- FIXME: Compress / decompress ?
-encode :: ByteString -> ByteString -> ByteString
-encode key payload = Base64.encode . DS.encode $ (payload, digest key payload)
-
-
-decode :: ByteString -> ByteString -> Either String ByteString
-decode key d = do
- (payload, dgst) <- DS.decode =<< Base64.decode d
- if dgst /= digest key payload
- then Left "junk"
- else Right payload
+encode :: ByteString -> Int -> ByteString -> IO (ByteString, CTime)
+encode key shelflife payload = do
+ now <- epochTime
+ let expiry = now + (CTime . fromIntegral $ shelflife)
+ d = DS.encode (payload, (\(CTime i64) -> i64) expiry)
+ return (Base64.encode . DS.encode $ (d, digest key d), expiry)
+
+
+decode :: ByteString -> ByteString -> IO (Either String ByteString)
+decode key raw = do
+ (CTime now) <- epochTime
+ return $ do
+ (d, dgst) <- DS.decode =<< Base64.decode raw
+ if dgst /= digest key d then Left "junk"
+ else do
+ (payload, expiry) <- DS.decode d
+ if expiry < now then Left "expired"
+ else Right payload
digest :: ByteString -> ByteString -> ByteString
diff --git a/src/Sproxy/Config.hs b/src/Sproxy/Config.hs
index e0f35a3..b011680 100644
--- a/src/Sproxy/Config.hs
+++ b/src/Sproxy/Config.hs
@@ -8,7 +8,6 @@ module Sproxy.Config (
import Control.Applicative (empty)
import Data.Aeson (FromJSON, parseJSON)
import Data.HashMap.Strict (HashMap)
-import Data.Int (Int64)
import Data.Text (Text)
import Data.Word (Word16)
import Data.Yaml (Value(Object), (.:), (.:?), (.!=))
@@ -64,7 +63,7 @@ data BackendConf = BackendConf {
, beSocket :: Maybe FilePath
, beCookieName :: String
, beCookieDomain :: Maybe String
-, beCookieMaxAge :: Int64
+, beCookieMaxAge :: Int
, beConnCount :: Int
} deriving (Show)