cors.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. package handlers
  2. import (
  3. "net/http"
  4. "strconv"
  5. "strings"
  6. )
  7. // CORSOption represents a functional option for configuring the CORS middleware.
  8. type CORSOption func(*cors) error
  9. type cors struct {
  10. h http.Handler
  11. allowedHeaders []string
  12. allowedMethods []string
  13. allowedOrigins []string
  14. allowedOriginValidator OriginValidator
  15. exposedHeaders []string
  16. maxAge int
  17. ignoreOptions bool
  18. allowCredentials bool
  19. }
  20. // OriginValidator takes an origin string and returns whether or not that origin is allowed.
  21. type OriginValidator func(string) bool
  22. var (
  23. defaultCorsMethods = []string{"GET", "HEAD", "POST"}
  24. defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
  25. // (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
  26. )
  27. const (
  28. corsOptionMethod string = "OPTIONS"
  29. corsAllowOriginHeader string = "Access-Control-Allow-Origin"
  30. corsExposeHeadersHeader string = "Access-Control-Expose-Headers"
  31. corsMaxAgeHeader string = "Access-Control-Max-Age"
  32. corsAllowMethodsHeader string = "Access-Control-Allow-Methods"
  33. corsAllowHeadersHeader string = "Access-Control-Allow-Headers"
  34. corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
  35. corsRequestMethodHeader string = "Access-Control-Request-Method"
  36. corsRequestHeadersHeader string = "Access-Control-Request-Headers"
  37. corsOriginHeader string = "Origin"
  38. corsVaryHeader string = "Vary"
  39. corsOriginMatchAll string = "*"
  40. )
  41. func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  42. origin := r.Header.Get(corsOriginHeader)
  43. if !ch.isOriginAllowed(origin) {
  44. ch.h.ServeHTTP(w, r)
  45. return
  46. }
  47. if r.Method == corsOptionMethod {
  48. if ch.ignoreOptions {
  49. ch.h.ServeHTTP(w, r)
  50. return
  51. }
  52. if _, ok := r.Header[corsRequestMethodHeader]; !ok {
  53. w.WriteHeader(http.StatusBadRequest)
  54. return
  55. }
  56. method := r.Header.Get(corsRequestMethodHeader)
  57. if !ch.isMatch(method, ch.allowedMethods) {
  58. w.WriteHeader(http.StatusMethodNotAllowed)
  59. return
  60. }
  61. requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
  62. allowedHeaders := []string{}
  63. for _, v := range requestHeaders {
  64. canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
  65. if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
  66. continue
  67. }
  68. if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
  69. w.WriteHeader(http.StatusForbidden)
  70. return
  71. }
  72. allowedHeaders = append(allowedHeaders, canonicalHeader)
  73. }
  74. if len(allowedHeaders) > 0 {
  75. w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
  76. }
  77. if ch.maxAge > 0 {
  78. w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
  79. }
  80. if !ch.isMatch(method, defaultCorsMethods) {
  81. w.Header().Set(corsAllowMethodsHeader, method)
  82. }
  83. } else {
  84. if len(ch.exposedHeaders) > 0 {
  85. w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
  86. }
  87. }
  88. if ch.allowCredentials {
  89. w.Header().Set(corsAllowCredentialsHeader, "true")
  90. }
  91. if len(ch.allowedOrigins) > 1 {
  92. w.Header().Set(corsVaryHeader, corsOriginHeader)
  93. }
  94. w.Header().Set(corsAllowOriginHeader, origin)
  95. if r.Method == corsOptionMethod {
  96. return
  97. }
  98. ch.h.ServeHTTP(w, r)
  99. }
  100. // CORS provides Cross-Origin Resource Sharing middleware.
  101. // Example:
  102. //
  103. // import (
  104. // "net/http"
  105. //
  106. // "github.com/gorilla/handlers"
  107. // "github.com/gorilla/mux"
  108. // )
  109. //
  110. // func main() {
  111. // r := mux.NewRouter()
  112. // r.HandleFunc("/users", UserEndpoint)
  113. // r.HandleFunc("/projects", ProjectEndpoint)
  114. //
  115. // // Apply the CORS middleware to our top-level router, with the defaults.
  116. // http.ListenAndServe(":8000", handlers.CORS()(r))
  117. // }
  118. //
  119. func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
  120. return func(h http.Handler) http.Handler {
  121. ch := parseCORSOptions(opts...)
  122. ch.h = h
  123. return ch
  124. }
  125. }
  126. func parseCORSOptions(opts ...CORSOption) *cors {
  127. ch := &cors{
  128. allowedMethods: defaultCorsMethods,
  129. allowedHeaders: defaultCorsHeaders,
  130. allowedOrigins: []string{corsOriginMatchAll},
  131. }
  132. for _, option := range opts {
  133. option(ch)
  134. }
  135. return ch
  136. }
  137. //
  138. // Functional options for configuring CORS.
  139. //
  140. // AllowedHeaders adds the provided headers to the list of allowed headers in a
  141. // CORS request.
  142. // This is an append operation so the headers Accept, Accept-Language,
  143. // and Content-Language are always allowed.
  144. // Content-Type must be explicitly declared if accepting Content-Types other than
  145. // application/x-www-form-urlencoded, multipart/form-data, or text/plain.
  146. func AllowedHeaders(headers []string) CORSOption {
  147. return func(ch *cors) error {
  148. for _, v := range headers {
  149. normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
  150. if normalizedHeader == "" {
  151. continue
  152. }
  153. if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
  154. ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
  155. }
  156. }
  157. return nil
  158. }
  159. }
  160. // AllowedMethods can be used to explicitly allow methods in the
  161. // Access-Control-Allow-Methods header.
  162. // This is a replacement operation so you must also
  163. // pass GET, HEAD, and POST if you wish to support those methods.
  164. func AllowedMethods(methods []string) CORSOption {
  165. return func(ch *cors) error {
  166. ch.allowedMethods = []string{}
  167. for _, v := range methods {
  168. normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
  169. if normalizedMethod == "" {
  170. continue
  171. }
  172. if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
  173. ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
  174. }
  175. }
  176. return nil
  177. }
  178. }
  179. // AllowedOrigins sets the allowed origins for CORS requests, as used in the
  180. // 'Allow-Access-Control-Origin' HTTP header.
  181. // Note: Passing in a []string{"*"} will allow any domain.
  182. func AllowedOrigins(origins []string) CORSOption {
  183. return func(ch *cors) error {
  184. for _, v := range origins {
  185. if v == corsOriginMatchAll {
  186. ch.allowedOrigins = []string{corsOriginMatchAll}
  187. return nil
  188. }
  189. }
  190. ch.allowedOrigins = origins
  191. return nil
  192. }
  193. }
  194. // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
  195. // 'Allow-Access-Control-Origin' HTTP header.
  196. func AllowedOriginValidator(fn OriginValidator) CORSOption {
  197. return func(ch *cors) error {
  198. ch.allowedOriginValidator = fn
  199. return nil
  200. }
  201. }
  202. // ExposeHeaders can be used to specify headers that are available
  203. // and will not be stripped out by the user-agent.
  204. func ExposedHeaders(headers []string) CORSOption {
  205. return func(ch *cors) error {
  206. ch.exposedHeaders = []string{}
  207. for _, v := range headers {
  208. normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
  209. if normalizedHeader == "" {
  210. continue
  211. }
  212. if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
  213. ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
  214. }
  215. }
  216. return nil
  217. }
  218. }
  219. // MaxAge determines the maximum age (in seconds) between preflight requests. A
  220. // maximum of 10 minutes is allowed. An age above this value will default to 10
  221. // minutes.
  222. func MaxAge(age int) CORSOption {
  223. return func(ch *cors) error {
  224. // Maximum of 10 minutes.
  225. if age > 600 {
  226. age = 600
  227. }
  228. ch.maxAge = age
  229. return nil
  230. }
  231. }
  232. // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
  233. // passing them through to the next handler. This is useful when your application
  234. // or framework has a pre-existing mechanism for responding to OPTIONS requests.
  235. func IgnoreOptions() CORSOption {
  236. return func(ch *cors) error {
  237. ch.ignoreOptions = true
  238. return nil
  239. }
  240. }
  241. // AllowCredentials can be used to specify that the user agent may pass
  242. // authentication details along with the request.
  243. func AllowCredentials() CORSOption {
  244. return func(ch *cors) error {
  245. ch.allowCredentials = true
  246. return nil
  247. }
  248. }
  249. func (ch *cors) isOriginAllowed(origin string) bool {
  250. if origin == "" {
  251. return false
  252. }
  253. if ch.allowedOriginValidator != nil {
  254. return ch.allowedOriginValidator(origin)
  255. }
  256. for _, allowedOrigin := range ch.allowedOrigins {
  257. if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
  258. return true
  259. }
  260. }
  261. return false
  262. }
  263. func (ch *cors) isMatch(needle string, haystack []string) bool {
  264. for _, v := range haystack {
  265. if v == needle {
  266. return true
  267. }
  268. }
  269. return false
  270. }