cors.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package handlers
  2. import (
  3. "net/http"
  4. "strconv"
  5. "strings"
  6. )
  // Types that make up the CORS middleware and its configuration API.
type (
	// CORSOption represents a functional option for configuring the CORS middleware.
	CORSOption func(*cors) error

	// cors is the http.Handler produced by CORS: the wrapped handler plus the
	// policy assembled from the CORSOption values.
	cors struct {
		h                      http.Handler    // next handler in the chain
		allowedHeaders         []string        // request headers accepted in preflights
		allowedMethods         []string        // methods accepted in preflights
		allowedOrigins         []string        // exact origins, or a single "*"
		allowedOriginValidator OriginValidator // when non-nil, decides origins instead of allowedOrigins
		exposedHeaders         []string        // sent in Access-Control-Expose-Headers
		maxAge                 int             // preflight cache lifetime in seconds; emitted only when > 0
		ignoreOptions          bool            // forward OPTIONS requests instead of answering them
		allowCredentials       bool            // emit Access-Control-Allow-Credentials: true
	}

	// OriginValidator takes an origin string and returns whether or not that origin is allowed.
	OriginValidator func(string) bool
)
  // Defaults installed by parseCORSOptions before any option is applied.
var (
	// defaultCorsMethods are the methods allowed when AllowedMethods is not used.
	defaultCorsMethods = []string{"GET", "HEAD", "POST"}

	// defaultCorsHeaders are always accepted in preflight requests. Origin is
	// included because WebKit/Safari v9 sends the Origin header by default in
	// AJAX requests.
	defaultCorsHeaders = []string{
		"Accept",
		"Accept-Language",
		"Content-Language",
		"Origin",
	}
)
  // HTTP method that identifies a CORS preflight request.
const corsOptionMethod string = "OPTIONS"

// Response headers written by the middleware.
const (
	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
	corsMaxAgeHeader           string = "Access-Control-Max-Age"
	corsVaryHeader             string = "Vary"
)

// Request headers inspected by the middleware.
const (
	corsOriginHeader         string = "Origin"
	corsRequestMethodHeader  string = "Access-Control-Request-Method"
	corsRequestHeadersHeader string = "Access-Control-Request-Headers"
)

// Wildcard origin meaning "any origin".
const corsOriginMatchAll string = "*"
  41. func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  42. origin := r.Header.Get(corsOriginHeader)
  43. if !ch.isOriginAllowed(origin) {
  44. ch.h.ServeHTTP(w, r)
  45. return
  46. }
  47. if r.Method == corsOptionMethod {
  48. if ch.ignoreOptions {
  49. ch.h.ServeHTTP(w, r)
  50. return
  51. }
  52. if _, ok := r.Header[corsRequestMethodHeader]; !ok {
  53. w.WriteHeader(http.StatusBadRequest)
  54. return
  55. }
  56. method := r.Header.Get(corsRequestMethodHeader)
  57. if !ch.isMatch(method, ch.allowedMethods) {
  58. w.WriteHeader(http.StatusMethodNotAllowed)
  59. return
  60. }
  61. requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
  62. allowedHeaders := []string{}
  63. for _, v := range requestHeaders {
  64. canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
  65. if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
  66. continue
  67. }
  68. if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
  69. w.WriteHeader(http.StatusForbidden)
  70. return
  71. }
  72. allowedHeaders = append(allowedHeaders, canonicalHeader)
  73. }
  74. if len(allowedHeaders) > 0 {
  75. w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
  76. }
  77. if ch.maxAge > 0 {
  78. w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
  79. }
  80. if !ch.isMatch(method, defaultCorsMethods) {
  81. w.Header().Set(corsAllowMethodsHeader, method)
  82. }
  83. } else {
  84. if len(ch.exposedHeaders) > 0 {
  85. w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
  86. }
  87. }
  88. if ch.allowCredentials {
  89. w.Header().Set(corsAllowCredentialsHeader, "true")
  90. }
  91. if len(ch.allowedOrigins) > 1 {
  92. w.Header().Set(corsVaryHeader, corsOriginHeader)
  93. }
  94. returnOrigin := origin
  95. for _, o := range ch.allowedOrigins {
  96. // A configuration of * is different than explicitly setting an allowed
  97. // origin. Returning arbitrary origin headers an an access control allow
  98. // origin header is unsafe and is not required by any use case.
  99. if o == corsOriginMatchAll {
  100. returnOrigin = "*"
  101. break
  102. }
  103. }
  104. w.Header().Set(corsAllowOriginHeader, returnOrigin)
  105. if r.Method == corsOptionMethod {
  106. return
  107. }
  108. ch.h.ServeHTTP(w, r)
  109. }
  110. // CORS provides Cross-Origin Resource Sharing middleware.
  111. // Example:
  112. //
  113. // import (
  114. // "net/http"
  115. //
  116. // "github.com/gorilla/handlers"
  117. // "github.com/gorilla/mux"
  118. // )
  119. //
  120. // func main() {
  121. // r := mux.NewRouter()
  122. // r.HandleFunc("/users", UserEndpoint)
  123. // r.HandleFunc("/projects", ProjectEndpoint)
  124. //
  125. // // Apply the CORS middleware to our top-level router, with the defaults.
  126. // http.ListenAndServe(":8000", handlers.CORS()(r))
  127. // }
  128. //
  129. func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
  130. return func(h http.Handler) http.Handler {
  131. ch := parseCORSOptions(opts...)
  132. ch.h = h
  133. return ch
  134. }
  135. }
  136. func parseCORSOptions(opts ...CORSOption) *cors {
  137. ch := &cors{
  138. allowedMethods: defaultCorsMethods,
  139. allowedHeaders: defaultCorsHeaders,
  140. allowedOrigins: []string{corsOriginMatchAll},
  141. }
  142. for _, option := range opts {
  143. option(ch)
  144. }
  145. return ch
  146. }
  147. //
  148. // Functional options for configuring CORS.
  149. //
  150. // AllowedHeaders adds the provided headers to the list of allowed headers in a
  151. // CORS request.
  152. // This is an append operation so the headers Accept, Accept-Language,
  153. // and Content-Language are always allowed.
  154. // Content-Type must be explicitly declared if accepting Content-Types other than
  155. // application/x-www-form-urlencoded, multipart/form-data, or text/plain.
  156. func AllowedHeaders(headers []string) CORSOption {
  157. return func(ch *cors) error {
  158. for _, v := range headers {
  159. normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
  160. if normalizedHeader == "" {
  161. continue
  162. }
  163. if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
  164. ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
  165. }
  166. }
  167. return nil
  168. }
  169. }
  170. // AllowedMethods can be used to explicitly allow methods in the
  171. // Access-Control-Allow-Methods header.
  172. // This is a replacement operation so you must also
  173. // pass GET, HEAD, and POST if you wish to support those methods.
  174. func AllowedMethods(methods []string) CORSOption {
  175. return func(ch *cors) error {
  176. ch.allowedMethods = []string{}
  177. for _, v := range methods {
  178. normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
  179. if normalizedMethod == "" {
  180. continue
  181. }
  182. if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
  183. ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
  184. }
  185. }
  186. return nil
  187. }
  188. }
  189. // AllowedOrigins sets the allowed origins for CORS requests, as used in the
  190. // 'Allow-Access-Control-Origin' HTTP header.
  191. // Note: Passing in a []string{"*"} will allow any domain.
  192. func AllowedOrigins(origins []string) CORSOption {
  193. return func(ch *cors) error {
  194. for _, v := range origins {
  195. if v == corsOriginMatchAll {
  196. ch.allowedOrigins = []string{corsOriginMatchAll}
  197. return nil
  198. }
  199. }
  200. ch.allowedOrigins = origins
  201. return nil
  202. }
  203. }
  204. // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
  205. // 'Allow-Access-Control-Origin' HTTP header.
  206. func AllowedOriginValidator(fn OriginValidator) CORSOption {
  207. return func(ch *cors) error {
  208. ch.allowedOriginValidator = fn
  209. return nil
  210. }
  211. }
  212. // ExposeHeaders can be used to specify headers that are available
  213. // and will not be stripped out by the user-agent.
  214. func ExposedHeaders(headers []string) CORSOption {
  215. return func(ch *cors) error {
  216. ch.exposedHeaders = []string{}
  217. for _, v := range headers {
  218. normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
  219. if normalizedHeader == "" {
  220. continue
  221. }
  222. if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
  223. ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
  224. }
  225. }
  226. return nil
  227. }
  228. }
  229. // MaxAge determines the maximum age (in seconds) between preflight requests. A
  230. // maximum of 10 minutes is allowed. An age above this value will default to 10
  231. // minutes.
  232. func MaxAge(age int) CORSOption {
  233. return func(ch *cors) error {
  234. // Maximum of 10 minutes.
  235. if age > 600 {
  236. age = 600
  237. }
  238. ch.maxAge = age
  239. return nil
  240. }
  241. }
  242. // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
  243. // passing them through to the next handler. This is useful when your application
  244. // or framework has a pre-existing mechanism for responding to OPTIONS requests.
  245. func IgnoreOptions() CORSOption {
  246. return func(ch *cors) error {
  247. ch.ignoreOptions = true
  248. return nil
  249. }
  250. }
  251. // AllowCredentials can be used to specify that the user agent may pass
  252. // authentication details along with the request.
  253. func AllowCredentials() CORSOption {
  254. return func(ch *cors) error {
  255. ch.allowCredentials = true
  256. return nil
  257. }
  258. }
  259. func (ch *cors) isOriginAllowed(origin string) bool {
  260. if origin == "" {
  261. return false
  262. }
  263. if ch.allowedOriginValidator != nil {
  264. return ch.allowedOriginValidator(origin)
  265. }
  266. for _, allowedOrigin := range ch.allowedOrigins {
  267. if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
  268. return true
  269. }
  270. }
  271. return false
  272. }
  273. func (ch *cors) isMatch(needle string, haystack []string) bool {
  274. for _, v := range haystack {
  275. if v == needle {
  276. return true
  277. }
  278. }
  279. return false
  280. }