aboutsummaryrefslogtreecommitdiff
path: root/src/Sproxy/Server.hs
blob: d5e396c20c624f9c5e2006ab4638cb5f165a906f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
module Sproxy.Server
  ( server
  ) where

import Control.Concurrent (forkIO)
import Control.Exception (IOException, bracketOnError, try)
import Control.Monad (void, when)
import Data.ByteString.Char8 (pack)
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Yaml.Include (decodeFileEither)
import Network.HTTP.Client
  ( Manager
  , ManagerSettings(..)
  , defaultManagerSettings
  , newManager
  , responseTimeoutMicro
  , socketConnection
  )
import Network.HTTP.Client.Internal (Connection)
import Network.Socket
  ( AddrInfoFlag(AI_NUMERICSERV)
  , Family(AF_INET, AF_UNIX)
  , SockAddr(SockAddrInet, SockAddrUnix)
  , Socket
  , SocketOption(ReuseAddr)
  , SocketType(Stream)
  , addrAddress
  , addrFamily
  , addrFlags
  , addrProtocol
  , addrSocketType
  , bind
  , close
  , connect
  , defaultHints
  , getAddrInfo
  , listen
  , maxListenQueue
  , setSocketOption
  , socket
  )
import Network.Wai (Application)
import Network.Wai.Handler.Warp
  ( Settings
  , defaultSettings
  , runSettingsSocket
  , setHTTP2Disabled
  , setOnException
  )
import Network.Wai.Handler.WarpTLS (runTLSSocket, tlsSettingsChain)
import System.Entropy (getEntropy)
import System.Environment (setEnv)
import System.Exit (exitFailure)
import System.FilePath.Glob (compile)
import System.IO (hPutStrLn, stderr)
import System.Posix.User
  ( GroupEntry(..)
  , UserEntry(..)
  , getAllGroupEntries
  , getRealUserID
  , getUserEntryForName
  , setGroupID
  , setGroups
  , setUserID
  )

import Sproxy.Application (redirect, sproxy)
import qualified Sproxy.Application.OAuth2 as OAuth2
import Sproxy.Application.OAuth2.Common (OAuth2Client)
import Sproxy.Config (BackendConf(..), ConfigFile(..), OAuth2Conf(..))
import qualified Sproxy.Logging as Log
import qualified Sproxy.Server.DB as DB

{- TODO:
 - Log.error && exitFailure should be replaced
 - by Log.fatal && wait for logger thread to print && exitFailure
-}
-- | Entry point of the proxy.
--
-- Reads the configuration from the given file, starts logging, binds the
-- listening socket(s), switches to the configured user when started as
-- root, prepares the data source, OAuth2 clients and backend managers,
-- and finally runs the Warp server on the main socket in the calling
-- thread.
server :: FilePath -> IO ()
server configFile = do
  cf <- readConfigFile configFile
  Log.start $ cfLogLevel cf
  -- Bind the main socket before the privilege switch further down
  -- (presumably so that privileged ports can still be bound as root).
  sock <- socket AF_INET Stream 0
  setSocketOption sock ReuseAddr 1
  bind sock $ SockAddrInet (fromIntegral $ cfListen cf) 0
  -- Optionally bind port 80 as well. When 'cfListen80' is not set, the
  -- default is to do so only if the main port is 443.
  maybe80 <-
    if fromMaybe (443 == cfListen cf) (cfListen80 cf)
      then do
        sock80 <- socket AF_INET Stream 0
        setSocketOption sock80 ReuseAddr 1
        bind sock80 $ SockAddrInet 80 0
        return (Just sock80)
      else return Nothing
  -- When the real user ID is 0, switch to the configured user:
  -- supplementary groups first, then the primary group, then the user ID.
  uid <- getRealUserID
  when (0 == uid) $ do
    let user = cfUser cf
    Log.info $ "switching to user " ++ show user
    u <- getUserEntryForName user
    -- Supplementary groups are all groups that list the user as a member.
    groupIDs <-
      map groupID . filter (elem user . groupMembers) <$> getAllGroupEntries
    setGroups groupIDs
    setGroupID $ userGroupID u
    setUserID $ userID u
  ds <- newDataSource cf
  db <- DB.start (cfHome cf) ds
  -- Use the configured key when present; otherwise obtain 64 bytes from
  -- 'getEntropy' and log that a new random key is in use.
  key <-
    maybe
      (Log.info "using new random key" >> getEntropy 64)
      (return . pack)
      (cfKey cf)
  -- Warp settings: HTTP/2 is disabled unless 'cfHTTP2' is set, and the
  -- exception callback ignores its arguments and does nothing.
  let settings =
        (if cfHTTP2 cf
           then id
           else setHTTP2Disabled) $
        setOnException (\_ _ -> return ()) defaultSettings
  oauth2clients <-
    HM.fromList <$> mapM newOAuth2Client (HM.toList (cfOAuth2 cf))
  -- One triple per backend: the backend name compiled as a glob pattern,
  -- the backend configuration itself, and its HTTP client manager.
  backends <-
    mapM
      (\be -> do
         m <- newBackendManager be
         return (compile $ beName be, be, m)) $
    cfBackends cf
  warpServer <- newServer cf
  -- If port 80 was bound above, serve the 'redirect' application on it in
  -- a separate thread. The redirect target port is 'cfHttpsPort', falling
  -- back to the main listening port.
  case maybe80 of
    Nothing -> return ()
    Just sock80 -> do
      let httpsPort = fromMaybe (cfListen cf) (cfHttpsPort cf)
      Log.info "listening on port 80 (HTTP redirect)"
      listen sock80 maxListenQueue
      void . forkIO $ runSettingsSocket settings sock80 (redirect httpsPort)
  -- XXX 2048 is from bindPortTCP from streaming-commons used internally by runTLS.
  -- XXX Since we don't call runTLS, we listen socket here with the same options.
  Log.info $ "proxy listening on port " ++ show (cfListen cf)
  listen sock (max 2048 maxListenQueue)
  -- Run the main application with the runner chosen by 'newServer'.
  warpServer settings sock (sproxy key db oauth2clients backends)

-- | Choose the data source described by the configuration.
--
-- * Only 'cfDatabase' set: a 'DB.PostgreSQL' source built from that string.
--   If 'cfPgPassFile' is also set, it is logged and exported as the
--   @PGPASSFILE@ environment variable first.
-- * Only 'cfDataFile' set: a 'DB.File' source.
-- * Neither set: 'Nothing'.
-- * Both set: an error is logged and the process exits with failure.
newDataSource :: ConfigFile -> IO (Maybe DB.DataSource)
newDataSource cf =
  case (cfDataFile cf, cfDatabase cf) of
    (Nothing, Nothing) -> return Nothing
    (Just path, Nothing) -> return (Just (DB.File path))
    (Nothing, Just connStr) -> do
      mapM_ exportPgPassFile (cfPgPassFile cf)
      return (Just (DB.PostgreSQL connStr))
    (Just _, Just _) -> do
      Log.error "only one data source can be used"
      exitFailure
  where
    -- Log the password file location and publish it through the environment.
    exportPgPassFile passFile = do
      Log.info $ "pgpassfile: " ++ show passFile
      setEnv "PGPASSFILE" passFile

-- | Instantiate the OAuth2 client for one configured provider.
--
-- The provider name is looked up in 'OAuth2.providers'. A known provider
-- is applied to the configured client ID and client secret, and the result
-- is returned paired with the name. An unknown name is logged as an error
-- and the process exits with failure.
newOAuth2Client :: (Text, OAuth2Conf) -> IO (Text, OAuth2Client)
newOAuth2Client (name, cfg) =
  maybe unsupported register (HM.lookup name OAuth2.providers)
  where
    credentials = (pack (oa2ClientId cfg), pack (oa2ClientSecret cfg))
    unsupported = do
      Log.error $ "OAuth2 provider " ++ show name ++ " is not supported"
      exitFailure
    register mkClient = do
      Log.info $ "oauth2: adding " ++ show name
      return (name, mkClient credentials)

-- | Create the HTTP client 'Manager' used to talk to one backend.
--
-- Exactly one of 'beSocket' (UNIX socket path) and 'bePort' (TCP port on
-- 'beAddress') must be configured; any other combination is logged as an
-- error and the process exits with failure. The manager opens every raw
-- connection with the chosen connector, ignoring the arguments the
-- 'managerRawConnection' callback receives. Connection count and response
-- timeout come from the backend configuration, with 'beTimeout' multiplied
-- by 1000000 before being passed to 'responseTimeoutMicro'.
newBackendManager :: BackendConf -> IO Manager
newBackendManager be = do
  connector <- pickConnector (beSocket be) (bePort be)
  let managerSettings =
        defaultManagerSettings
          { managerRawConnection = return $ \_ _ _ -> connector
          , managerConnCount = beConnCount be
          , managerResponseTimeout =
              responseTimeoutMicro (1000000 * beTimeout be)
          }
  newManager managerSettings
  where
    pickConnector (Just path) Nothing = do
      Log.info $ "backend `" ++ beName be ++ "' on UNIX socket " ++ path
      return (openUnixSocketConnection path)
    pickConnector Nothing (Just port) = do
      let svc = show port
          address = beAddress be
      Log.info $ "backend `" ++ beName be ++ "' on " ++ address ++ ":" ++ svc
      return (openTCPConnection address svc)
    pickConnector _ _ = do
      Log.error "either backend port number or UNIX socket path is required."
      exitFailure

-- | Select the function used to run the Warp application on a socket.
--
-- With 'cfSsl' enabled, both 'cfSslKey' and 'cfSslCert' must be set. The
-- result is then 'runTLSSocket' configured with the certificate, the
-- chain from 'cfSslCertChain' and the key. If either the key or the
-- certificate is missing, an error is logged and the process exits with
-- failure. With 'cfSsl' disabled, a warning is logged and plain
-- 'runSettingsSocket' is returned.
newServer :: ConfigFile -> IO (Settings -> Socket -> Application -> IO ())
newServer cf
  | cfSsl cf =
    case (cfSslKey cf, cfSslCert cf) of
      (Just k, Just c) ->
        return $ runTLSSocket (tlsSettingsChain c (cfSslCertChain cf) k)
      _ -> do
        -- This branch is reached when the key, the certificate, or both
        -- are absent, so the message mentions both.
        Log.error "missing SSL key or certificate"
        exitFailure
  | otherwise = do
    Log.warn "not using SSL!"
    return runSettingsSocket

-- | Open an HTTP client 'Connection' over the UNIX socket at the given path.
--
-- A stream socket is created, connected to the path and wrapped with
-- 'socketConnection' using a chunk size of 8192. If connecting or wrapping
-- throws, the socket is closed before the exception propagates.
openUnixSocketConnection :: FilePath -> IO Connection
openUnixSocketConnection f = bracketOnError acquire close attach
  where
    acquire = socket AF_UNIX Stream 0
    attach sock = connect sock (SockAddrUnix f) >> socketConnection sock 8192

-- | Open an HTTP client 'Connection' to the given host and numeric service
-- (port) over TCP.
--
-- The host is resolved with 'getAddrInfo' and the resulting addresses are
-- tried in the order returned. The first one that connects is wrapped with
-- 'socketConnection' using a chunk size of 8192. An 'IOException' from any
-- address but the last is discarded and the next address is tried; the
-- failure of the last address propagates to the caller. An empty
-- resolution result is reported as an 'IOError' naming the host and
-- service.
openTCPConnection :: String -> String -> IO Connection
openTCPConnection host svc =
  getAddrInfo (Just hints) (Just host) (Just svc) >>= firstSuccess
  where
    hints = defaultHints {addrFlags = [AI_NUMERICSERV], addrSocketType = Stream}
    -- Try each resolved address in turn, e.g. so that a host resolving to
    -- both an IPv6 and an IPv4 address still works when only one of them
    -- accepts connections.
    firstSuccess [] =
      ioError . userError $ "no addresses found for " ++ host ++ ":" ++ svc
    firstSuccess [addr] = openAddr addr
    firstSuccess (addr:rest) = do
      r <- try (openAddr addr) :: IO (Either IOException Connection)
      case r of
        Right conn -> return conn
        Left _ -> firstSuccess rest
    -- Connect to a single address, closing the socket if anything fails
    -- before the connection is handed back.
    openAddr addr =
      bracketOnError
        (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
        close
        (\s -> do
           connect s (addrAddress addr)
           socketConnection s 8192)

-- | Decode the configuration file at the given path.
--
-- On a decoding error, a @FATAL:@ line containing the path and the shown
-- error is written to standard error and the process exits with failure.
readConfigFile :: FilePath -> IO ConfigFile
readConfigFile f = decodeFileEither f >>= either bail return
  where
    bail err = do
      hPutStrLn stderr $ "FATAL: " ++ f ++ ": " ++ show err
      exitFailure