aboutsummaryrefslogtreecommitdiff
path: root/src/Sproxy/Server.hs
blob: bd2af17d0ddebf4a4bd8e87804090e1dfd5c29c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
module Sproxy.Server (
  server
) where

import Control.Concurrent (forkIO)
import Control.Exception (bracketOnError)
import Control.Monad (void, when)
import Data.ByteString as BS (hGetLine, readFile)
import Data.ByteString.Char8 (pack)
import Data.HashMap.Strict as HM (fromList, lookup, toList)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Word (Word16)
import Data.Yaml (decodeFileEither)
import Network.HTTP.Client (Manager, ManagerSettings(..), defaultManagerSettings, newManager, socketConnection)
import Network.HTTP.Client.Internal (Connection)
import Network.Socket ( Family(AF_INET, AF_UNIX), SockAddr(SockAddrInet, SockAddrUnix),
  SocketOption(ReuseAddr), SocketType(Stream), bind, close, connect, inet_addr,
  listen, maxListenQueue, setSocketOption, socket )
import Network.Wai.Handler.WarpTLS (tlsSettingsChain, runTLSSocket)
import Network.Wai.Handler.Warp (defaultSettings, setHTTP2Disabled, runSettingsSocket)
import System.Entropy (getEntropy)
import System.Environment (setEnv)
import System.Exit (exitFailure)
import System.FilePath.Glob (compile)
import System.IO (IOMode(ReadMode), hIsEOF, hPutStrLn, stderr, withFile)
import System.Posix.User ( GroupEntry(..), UserEntry(..),
  getAllGroupEntries, getRealUserID,
  getUserEntryForName, setGroupID, setGroups, setUserID )

import Sproxy.Application (sproxy, redirect)
import Sproxy.Application.OAuth2.Common (OAuth2Client)
import Sproxy.Config (BackendConf(..), ConfigFile(..), OAuth2Conf(..))
import qualified Sproxy.Application.OAuth2 as OAuth2
import qualified Sproxy.Logging as Log
import qualified Sproxy.Server.DB as DB


-- | Run sproxy using the configuration file at @configFile@.
--
-- Start-up happens in a deliberate order:
--
--   1. read the configuration and start logging;
--   2. bind the HTTPS socket and, if enabled, the port-80 redirect socket
--      (done first, while the process may still be root, so that
--      privileged ports can be bound);
--   3. if the real user ID is 0, switch to 'cfUser' (supplementary groups,
--      then primary group, then user ID);
--   4. set up the database, session key, OAuth2 clients and one
--      connection manager per backend;
--   5. serve HTTPS on the calling thread.
--
-- The port-80 redirector, when enabled, is served on a thread started with
-- 'forkIO'.
server :: FilePath -> IO ()
server configFile = do
  cf <- readConfigFile configFile
  Log.start (cfLogLevel cf)
  Log.debug (show cf)

  httpsSock <- bindAnyAddr (fromIntegral (cfListen cf))

  -- The redirector is on by default exactly when HTTPS listens on 443;
  -- an explicit 'cfListen80' setting overrides that default.
  let wantRedirect = fromMaybe (cfListen cf == 443) (cfListen80 cf)
  redirectSock <-
    if wantRedirect
      then Just <$> bindAnyAddr 80
      else return Nothing

  realUid <- getRealUserID
  when (realUid == 0) $ dropPrivileges (cfUser cf)

  mapM_ exportPgPassFile (cfPgPassFile cf)

  db <- DB.start (cfHome cf) (newDataSource cf)

  key <- case cfKey cf of
    Nothing   -> Log.info "using new random key" >> getEntropy 32
    Just path -> Log.info ("reading key from " ++ path) >> BS.readFile path

  mapM_ (startRedirector (cfListen cf)) redirectSock

  oauth2clients <- HM.fromList <$> traverse newOAuth2Client (HM.toList (cfOAuth2 cf))

  backends <- traverse attachManager (cfBackends cf)

  let tlsSettings = tlsSettingsChain (cfSslCert cf) (cfSslCertChain cf) (cfSslKey cf)
      warpSettings
        | cfHTTP2 cf = defaultSettings
        | otherwise  = setHTTP2Disabled defaultSettings

  -- The backlog of 2048 mirrors bindPortTCP (streaming-commons), which
  -- runTLS uses internally to create its socket.  We hand our own
  -- pre-bound socket to runTLSSocket instead, so listen with the same
  -- backlog here.
  Log.info $ "listening on port " ++ show (cfListen cf) ++ " (HTTPS)"
  listen httpsSock (max 2048 maxListenQueue)
  runTLSSocket tlsSettings warpSettings httpsSock (sproxy key db oauth2clients backends)
  where
    -- New IPv4 stream socket with SO_REUSEADDR, bound to @port@ with host
    -- address 0 (any local address).  Not yet listening.
    bindAnyAddr port = do
      s <- socket AF_INET Stream 0
      setSocketOption s ReuseAddr 1
      bind s (SockAddrInet port 0)
      return s

    -- Switch to @user@: supplementary groups are every group that lists
    -- @user@ as a member; the primary group and UID come from the user's
    -- passwd entry.  Groups must change before the UID does.
    dropPrivileges user = do
      Log.info $ "switching to user " ++ show user
      entry <- getUserEntryForName user
      supplementary <- map groupID . filter (elem user . groupMembers)
                         <$> getAllGroupEntries
      setGroups supplementary
      setGroupID (userGroupID entry)
      setUserID (userID entry)

    -- Expose the configured pgpass file to the PostgreSQL client library.
    exportPgPassFile path = do
      Log.info $ "pgpassfile: " ++ show path
      setEnv "PGPASSFILE" path

    -- Start listening on the port-80 socket and serve redirects to the
    -- HTTPS port on a background thread.
    startRedirector httpsPort sock = do
      Log.info "listening on port 80 (HTTP redirect)"
      listen sock maxListenQueue
      void . forkIO $ runSettingsSocket defaultSettings sock (redirect httpsPort)

    -- Pair a backend with its compiled name glob and a connection manager.
    attachManager be = do
      mgr <- newBackendManager be
      return (compile (beName be), be, mgr)


-- | The data source to use, if any.  When the 'cfDatabase' setting is
-- present, its value is wrapped as a 'DB.PostgreSQL' data source;
-- otherwise there is no data source.
newDataSource :: ConfigFile -> Maybe DB.DataSource
newDataSource = fmap DB.PostgreSQL . cfDatabase


-- | Build the OAuth2 client for the provider called @name@.
--
-- @name@ must be a key of 'OAuth2.providers'; otherwise an error is logged
-- and the process exits with failure.  The client secret is the first line
-- of the file named by 'oa2ClientSecret'.  If that file is empty, an error
-- is logged and an empty secret is used.
newOAuth2Client :: (Text, OAuth2Conf) -> IO (Text, OAuth2Client)
newOAuth2Client (name, cfg) =
  maybe unsupported configure (HM.lookup name OAuth2.providers)
  where
    clientId   = oa2ClientId cfg
    secretFile = oa2ClientSecret cfg

    unsupported = do
      Log.error $ "OAuth2 provider " ++ show name ++ " is not supported"
      exitFailure

    configure provider = do
      Log.info $ "oauth2: adding " ++ show name
      secret <- withFile secretFile ReadMode readSecret
      return (name, provider (pack clientId, secret))

    -- Only the first line of the file is used as the secret.
    readSecret h = do
      atEof <- hIsEOF h
      if atEof
        then do
          Log.error $ "oauth2: empty secret file for "
                    ++ show name ++ ": " ++ show secretFile
          return (pack "")
        else BS.hGetLine h


-- | Create an HTTP connection manager for a single backend.
--
-- A backend must set exactly one of 'beSocket' (a UNIX socket path) or
-- 'bePort' (a TCP port on 'beAddress').  Setting neither or both is a
-- configuration error: it is logged with the backend's name and the
-- process exits with failure.
--
-- The manager's raw-connection function ignores the arguments http-client
-- passes to it, so every connection it opens goes to this backend's
-- endpoint.  'beConnCount' is passed through as 'managerConnCount'.
newBackendManager :: BackendConf -> IO Manager
newBackendManager be = do
  openConn <-
    case (beSocket be, bePort be) of
      (Just path, Nothing) -> do
        Log.info $ "backend `" ++ beName be ++ "' on UNIX socket " ++ path
        return $ openUnixSocketConnection path

      (Nothing, Just port) -> do
        Log.info $ "backend `" ++ beName be ++ "' on " ++ beAddress be ++ ":" ++ show port
        return $ openTCPConnection (beAddress be) port

      -- Previously this fell into a catch-all that claimed one of the two
      -- was *missing*, which is misleading when both are given.
      (Just _, Just _) -> do
        Log.error $ "backend `" ++ beName be
                  ++ "': backend port number and UNIX socket path are mutually exclusive."
        exitFailure

      (Nothing, Nothing) -> do
        Log.error $ "backend `" ++ beName be
                  ++ "': either backend port number or UNIX socket path is required."
        exitFailure

  newManager defaultManagerSettings
    { managerRawConnection = return $ \_ _ _ -> openConn
    , managerConnCount = beConnCount be
    }


-- | Connect a stream socket to the UNIX-domain socket at @path@ and wrap it
-- as an http-client 'Connection' with an 8192-byte read chunk size.  If
-- connecting or wrapping throws, the socket is closed before the exception
-- propagates.
openUnixSocketConnection :: FilePath -> IO Connection
openUnixSocketConnection path =
  bracketOnError (socket AF_UNIX Stream 0) close $ \sock ->
    connect sock (SockAddrUnix path) >> socketConnection sock chunkSize
  where
    chunkSize = 8192


-- | Open an IPv4 TCP connection to @addr@:@port@ and wrap it as an
-- http-client 'Connection' with an 8192-byte read chunk size.
--
-- @addr@ is parsed with 'inet_addr' (numeric IPv4 notation); no host-name
-- resolution happens here.  If parsing, connecting or wrapping throws, the
-- socket is closed before the exception propagates.
openTCPConnection :: String -> Word16 -> IO Connection
openTCPConnection addr port =
  bracketOnError (socket AF_INET Stream 0) close $ \sock -> do
    host <- inet_addr addr
    connect sock (SockAddrInet (fromIntegral port) host)
    socketConnection sock chunkSize
  where
    chunkSize = 8192


-- | Decode the YAML configuration file at @f@.
--
-- On a decode error, a @FATAL@ message naming the file is written to
-- stderr and the process exits with failure.  stderr is used rather than
-- the logger because 'server' reads the configuration before calling
-- 'Log.start'.
readConfigFile :: FilePath -> IO ConfigFile
readConfigFile f = decodeFileEither f >>= either fatal return
  where
    fatal e = do
      hPutStrLn stderr $ "FATAL: " ++ f ++ ": " ++ show e
      exitFailure