...
Source file
src/net/http/csrf.go
1
2
3
4
5 package http
6
7 import (
8 "errors"
9 "fmt"
10 "net/url"
11 "sync"
12 "sync/atomic"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36 type CrossOriginProtection struct {
37 bypass atomic.Pointer[ServeMux]
38 trustedMu sync.RWMutex
39 trusted map[string]bool
40 deny atomic.Pointer[Handler]
41 }
42
43
44 func NewCrossOriginProtection() *CrossOriginProtection {
45 return &CrossOriginProtection{}
46 }
47
48
49
50
51
52
53
54
55
56
// AddTrustedOrigin allows all requests whose Origin header exactly matches
// origin, even if they would otherwise be rejected as cross-origin.
//
// origin must have the form "scheme://host[:port]". It must have a scheme and
// a host, and it must not have a path (not even a trailing "/"), a query, or a
// fragment. The value is stored as given and compared byte-for-byte against
// the request's Origin header.
//
// It returns an error if origin fails to parse or violates these rules.
func (c *CrossOriginProtection) AddTrustedOrigin(origin string) error {
	parsed, parseErr := url.Parse(origin)
	if parseErr != nil {
		return fmt.Errorf("invalid origin %q: %w", origin, parseErr)
	}

	switch {
	case parsed.Scheme == "":
		return fmt.Errorf("invalid origin %q: scheme is required", origin)
	case parsed.Host == "":
		return fmt.Errorf("invalid origin %q: host is required", origin)
	case parsed.Path != "" || parsed.RawQuery != "" || parsed.Fragment != "":
		return fmt.Errorf("invalid origin %q: path, query, and fragment are not allowed", origin)
	}

	c.trustedMu.Lock()
	defer c.trustedMu.Unlock()
	// The set is allocated lazily so that the zero value stays usable.
	if c.trusted == nil {
		c.trusted = map[string]bool{}
	}
	c.trusted[origin] = true
	return nil
}
79
80 var noopHandler = HandlerFunc(func(w ResponseWriter, r *Request) {})
81
82
83
84
85
86
87 func (c *CrossOriginProtection) AddInsecureBypassPattern(pattern string) {
88 var bypass *ServeMux
89
90
91 for {
92 bypass = c.bypass.Load()
93 if bypass != nil {
94 break
95 }
96 bypass = NewServeMux()
97 if c.bypass.CompareAndSwap(nil, bypass) {
98 break
99 }
100 }
101
102 bypass.Handle(pattern, noopHandler)
103 }
104
105
106
107
108
109
110
111
112 func (c *CrossOriginProtection) SetDenyHandler(h Handler) {
113 if h == nil {
114 c.deny.Store(nil)
115 return
116 }
117 c.deny.Store(&h)
118 }
119
120
121
// Check applies cross-origin checks to req. It returns a non-nil error if the
// request should be rejected.
//
// GET, HEAD, and OPTIONS requests are always allowed. For other methods:
//   - a Sec-Fetch-Site header of "same-origin" or "none" is accepted, and any
//     other non-empty value is rejected;
//   - without Sec-Fetch-Site, the request is accepted if it has no Origin
//     header or if the Origin's host equals req.Host, and rejected otherwise.
//
// A request that would be rejected is still accepted if it matches a pattern
// added with AddInsecureBypassPattern or carries an Origin added with
// AddTrustedOrigin.
func (c *CrossOriginProtection) Check(req *Request) error {
	if m := req.Method; m == "GET" || m == "HEAD" || m == "OPTIONS" {
		// Safe methods are always allowed.
		return nil
	}

	if site := req.Header.Get("Sec-Fetch-Site"); site != "" {
		// When the browser reports Sec-Fetch-Site, that header alone decides.
		// The exemption lookup takes a lock, so it runs only on rejection.
		if site == "same-origin" || site == "none" || c.isRequestExempt(req) {
			return nil
		}
		return errCrossOriginRequest
	}

	origin := req.Header.Get("Origin")
	if origin == "" {
		// Neither Sec-Fetch-Site nor Origin is present, so the request is
		// either same-origin or not from a browser.
		return nil
	}

	// Host carries no scheme, so an HTTP<->HTTPS request for the same host
	// passes this comparison. This fallback path fails open in that case.
	if parsed, err := url.Parse(origin); err == nil && parsed.Host == req.Host {
		return nil
	}

	if c.isRequestExempt(req) {
		return nil
	}
	return errCrossOriginRequestFromOldBrowser
}
164
165 var (
166 errCrossOriginRequest = errors.New("cross-origin request detected from Sec-Fetch-Site header")
167 errCrossOriginRequestFromOldBrowser = errors.New("cross-origin request detected, and/or browser is out of date: " +
168 "Sec-Fetch-Site is missing, and Origin does not match Host")
169 )
170
171
172
173 func (c *CrossOriginProtection) isRequestExempt(req *Request) bool {
174 if bypass := c.bypass.Load(); bypass != nil {
175 if _, pattern := bypass.Handler(req); pattern != "" {
176
177 return true
178 }
179 }
180
181 c.trustedMu.RLock()
182 defer c.trustedMu.RUnlock()
183 origin := req.Header.Get("Origin")
184
185 return origin != "" && c.trusted[origin]
186 }
187
188
189
190
191
192
193
194 func (c *CrossOriginProtection) Handler(h Handler) Handler {
195 return HandlerFunc(func(w ResponseWriter, r *Request) {
196 if err := c.Check(r); err != nil {
197 if deny := c.deny.Load(); deny != nil {
198 (*deny).ServeHTTP(w, r)
199 return
200 }
201 Error(w, err.Error(), StatusForbidden)
202 return
203 }
204 h.ServeHTTP(w, r)
205 })
206 }
207
View as plain text