summaryrefslogtreecommitdiff
path: root/vendor/github.com/emicklei/go-restful/container.go
blob: 657d5b6ddd00a630e8a3cb5ab5e0602475d26d8c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
package restful

// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.

import (
	"bytes"
	"errors"
	"fmt"
	"net/http"
	"os"
	"runtime"
	"strings"
	"sync"

	"github.com/emicklei/go-restful/log"
)

// Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
// The requests are further dispatched to routes of WebServices using a RouteSelector.
type Container struct {
	// webServicesLock guards webServices.
	webServicesLock        sync.RWMutex
	webServices            []*WebService
	ServeMux               *http.ServeMux
	// isRegisteredOnRoot is true once a WebService has been mapped on "/".
	isRegisteredOnRoot     bool
	// containerFilters run before the filters of the WebService and of the Route.
	containerFilters       []FilterFunction
	doNotRecover           bool // default is true
	// recoverHandleFunc is called with the recovered value when a panic is caught.
	recoverHandleFunc      RecoverHandleFunction
	// serviceErrorHandleFunc is called when route selection fails with a ServiceError.
	serviceErrorHandleFunc ServiceErrorHandleFunction
	router                 RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
	contentEncodingEnabled bool          // default is false
}

// NewContainer returns an empty Container that owns a fresh http.ServeMux.
// Defaults: CurlyRouter as router, panic recovery switched off (doNotRecover
// is true), content encoding disabled, logStackOnRecover as recover handler
// and writeServiceError as service error handler.
func NewContainer() *Container {
	c := new(Container)
	c.webServices = []*WebService{}
	c.ServeMux = http.NewServeMux()
	c.isRegisteredOnRoot = false
	c.containerFilters = []FilterFunction{}
	c.doNotRecover = true
	c.recoverHandleFunc = logStackOnRecover
	c.serviceErrorHandleFunc = writeServiceError
	c.router = CurlyRouter{}
	c.contentEncodingEnabled = false
	return c
}

// RecoverHandleFunction declares functions that can be used to handle a panic situation.
// The first argument is what recover() returned. The second is the response writer
// of the failed request and must be used to communicate an error response.
type RecoverHandleFunction func(interface{}, http.ResponseWriter)

// RecoverHandler changes the default function (logStackOnRecover) to be called
// when a panic is detected. The handler is only invoked when panic recovery is
// enabled; a new Container has it disabled, so call DoNotRecover(false) as well.
func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
	c.recoverHandleFunc = handler
}

// ServiceErrorHandleFunction declares functions that can be used to handle a service error situation.
// The first argument is the service error, the second is the request that resulted in the error and
// the third is the response, which must be used to communicate an error response.
type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)

// ServiceErrorHandler changes the default function (writeServiceError) to be called
// when route selection for a request fails with a ServiceError.
func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
	c.serviceErrorHandleFunc = handler
}

// DoNotRecover controls whether panics will be caught to return HTTP 500.
// If set to true, Route functions are responsible for handling any error situation.
// If set to false, a panic raised while dispatching is recovered and passed to
// the recover handler (see RecoverHandler).
// Default value is true.
func (c *Container) DoNotRecover(doNot bool) {
	c.doNotRecover = doNot
}

// Router replaces the RouteSelector used to find the WebService and Route for
// a request. The default is a CurlyRouter.
func (c *Container) Router(aRouter RouteSelector) {
	c.router = aRouter
}

// EnableContentEncoding (default=false) allows for GZIP or DEFLATE encoding of responses.
// When enabled, each dispatched request is checked to see whether it asks for
// a compressed response.
func (c *Container) EnableContentEncoding(enabled bool) {
	c.contentEncodingEnabled = enabled
}

// Add registers a WebService with the Container and returns the Container so
// that calls can be chained. A service without a root path gets "/".
// If an already registered service has the same root path, the conflict is
// logged and the process exits with status 1.
func (c *Container) Add(service *WebService) *Container {
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	// lazily default the root path
	if len(service.rootPath) == 0 {
		service.Path("/")
	}

	// root paths must be unique within the container
	for _, registered := range c.webServices {
		if registered.RootPath() != service.RootPath() {
			continue
		}
		log.Printf("[restful] WebService with duplicate root path detected:['%v']", registered)
		os.Exit(1)
	}

	// once the container is mapped on "/" no further mux patterns are added
	if !c.isRegisteredOnRoot {
		c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
	}
	c.webServices = append(c.webServices, service)
	return c
}

// addHandler maps the fixed prefix of the service's root path onto c.dispatch
// in serveMux, unless a service with the same root path is already present in
// c.webServices. It reports whether the mapping was made on root ("/").
// The caller must hold webServicesLock.
func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
	pattern := fixedPrefixPath(service.RootPath())

	// an empty or "/" prefix means the container serves everything
	if pattern == "" || pattern == "/" {
		serveMux.HandleFunc("/", c.dispatch)
		return true
	}

	// nothing to do when the root path is already known
	for _, known := range c.webServices {
		if known.RootPath() == service.RootPath() {
			return false
		}
	}

	// register the pattern both without and with a trailing slash
	serveMux.HandleFunc(pattern, c.dispatch)
	if !strings.HasSuffix(pattern, "/") {
		serveMux.HandleFunc(pattern+"/", c.dispatch)
	}
	return false
}

// Remove unregisters every WebService whose root path equals that of ws.
// http.ServeMux cannot drop a single pattern, so a new ServeMux is built and
// the remaining WebServices are registered on it again.
// It returns an error, and changes nothing, when the Container uses
// http.DefaultServeMux.
//
// NOTE(review): re-registration goes through addHandler, so two remaining
// services that share the same fixed prefix will hit the duplicate-pattern
// panic of http.ServeMux, just as they would when added without a root mapping.
func (c *Container) Remove(ws *WebService) error {
	if c.ServeMux == http.DefaultServeMux {
		errMsg := fmt.Sprintf("[restful] cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
		// Print, not Printf: errMsg is data and may itself contain '%'.
		log.Print(errMsg)
		return errors.New(errMsg)
	}
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	// addHandler skips a service whose root path is already in c.webServices.
	// The list is therefore rebuilt in place, so that at each call it contains
	// only the services that have been re-registered so far. Iterating the old
	// list while leaving c.webServices untouched would make addHandler find
	// every service "already mapped" and leave the new ServeMux empty.
	// The write lock keeps readers of c.webServices away from the partial list.
	remaining := c.webServices
	c.webServices = []*WebService{}
	newServeMux := http.NewServeMux()
	newIsRegisteredOnRoot := false
	for _, each := range remaining {
		if each.rootPath == ws.rootPath {
			continue
		}
		// If not registered on root then add specific mapping
		if !newIsRegisteredOnRoot {
			newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
		}
		c.webServices = append(c.webServices, each)
	}
	c.ServeMux, c.isRegisteredOnRoot = newServeMux, newIsRegisteredOnRoot
	return nil
}

// logStackOnRecover is the default RecoverHandleFunction; it is used when
// DoNotRecover is false and no other handler was set with RecoverHandler.
// It collects the panic reason and the file:line of each stack frame, logs
// that text, and sends it as the body of a 500 response.
// This may be a security issue as it exposes sourcecode information.
func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
	var trace bytes.Buffer
	fmt.Fprintf(&trace, "[restful] recover from panic situation: - %v\r\n", panicReason)
	// start at frame 2: skip this function and its direct caller
	for depth := 2; ; depth++ {
		_, file, line, ok := runtime.Caller(depth)
		if !ok {
			break
		}
		fmt.Fprintf(&trace, "    %s:%d\r\n", file, line)
	}
	log.Print(trace.String())
	httpWriter.WriteHeader(http.StatusInternalServerError)
	httpWriter.Write(trace.Bytes())
}

// writeServiceError is the default ServiceErrorHandleFunction and is called
// when a ServiceError is returned during route selection. Default implementation
// calls resp.WriteErrorString(err.Code, err.Message); req is not used.
func writeServiceError(err ServiceError, req *Request, resp *Response) {
	resp.WriteErrorString(err.Code, err.Message)
}

// Dispatch hands the incoming http request to the matching WebService.
// It panics when httpWriter or httpRequest is nil.
func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	switch {
	case httpWriter == nil:
		panic("httpWriter cannot be nil")
	case httpRequest == nil:
		panic("httpRequest cannot be nil")
	}
	c.dispatch(httpWriter, httpRequest)
}

// dispatch does the work for Dispatch and is the handler that addHandler
// registers on the ServeMux. It optionally installs panic recovery and response
// compression, asks the router for the matching WebService and Route, and then
// runs the applicable filters followed by the Route function.
func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	writer := httpWriter

	// CompressingResponseWriter should be closed after all operations are done.
	// This defer is registered first, so it runs last (after panic recovery).
	// It inspects writer at that time, i.e. after a possible reassignment below.
	defer func() {
		if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
			compressWriter.Close()
		}
	}()

	// Install panic recovery unless told otherwise
	if !c.doNotRecover { // catch all for 500 response
		defer func() {
			if r := recover(); r != nil {
				c.recoverHandleFunc(r, writer)
				return
			}
		}()
	}

	// Detect if compression is needed
	// assume without compression, test for override
	if c.contentEncodingEnabled {
		doCompress, encoding := wantsCompressedResponse(httpRequest)
		if doCompress {
			var err error
			// NOTE(review): writer is reassigned even when err is non-nil, so the
			// deferred Close above may see the value returned alongside the error;
			// confirm NewCompressingResponseWriter returns something safe to Close.
			writer, err = NewCompressingResponseWriter(httpWriter, encoding)
			if err != nil {
				log.Print("[restful] unable to install compressor: ", err)
				httpWriter.WriteHeader(http.StatusInternalServerError)
				return
			}
		}
	}
	// Find best match Route ; err is non nil if no match was found
	var webService *WebService
	var route *Route
	var err error
	// the closure limits the read lock to the route selection itself
	func() {
		c.webServicesLock.RLock()
		defer c.webServicesLock.RUnlock()
		webService, route, err = c.router.SelectRoute(
			c.webServices,
			httpRequest)
	}()
	if err != nil {
		// No route was selected. The error response is written by the target of
		// the chain, so the container filters run first; they should not touch
		// the response...
		chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
			switch err.(type) {
			case ServiceError:
				ser := err.(ServiceError)
				c.serviceErrorHandleFunc(ser, req, resp)
			}
			// TODO: an error that is not a ServiceError is not reported here
		}}
		chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
		return
	}
	wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest)
	// pass through filters (if any)
	if len(c.containerFilters)+len(webService.filters)+len(route.Filters) > 0 {
		// compose filter chain: container filters, then WebService filters, then Route filters
		allFilters := []FilterFunction{}
		allFilters = append(allFilters, c.containerFilters...)
		allFilters = append(allFilters, webService.filters...)
		allFilters = append(allFilters, route.Filters...)
		chain := FilterChain{Filters: allFilters, Target: func(req *Request, resp *Response) {
			// handle request by route after passing all filters
			route.Function(wrappedRequest, wrappedResponse)
		}}
		chain.ProcessFilter(wrappedRequest, wrappedResponse)
	} else {
		// no filters, handle request by route
		route.Function(wrappedRequest, wrappedResponse)
	}
}

// fixedPrefixPath returns the fixed part of pathspec: everything that precedes
// the first template variable, i.e. the first "{". A pathspec without "{" is
// returned unchanged.
func fixedPrefixPath(pathspec string) string {
	if brace := strings.IndexByte(pathspec, '{'); brace >= 0 {
		return pathspec[:brace]
	}
	return pathspec
}

// ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server.
// The request is handed to the Container's ServeMux.
func (c *Container) ServeHTTP(httpwriter http.ResponseWriter, httpRequest *http.Request) {
	c.ServeMux.ServeHTTP(httpwriter, httpRequest)
}

// Handle registers the handler for the given pattern on the Container's ServeMux.
// If a handler already exists for pattern, Handle panics.
func (c *Container) Handle(pattern string, handler http.Handler) {
	c.ServeMux.Handle(pattern, handler)
}

// HandleWithFilter registers the handler for the given pattern and runs the
// Container's filters in front of it. The filters are looked up for each
// request; when there are none the handler is called directly.
// If a handler already exists for pattern, HandleWithFilter panics.
func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
	c.Handle(pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if len(c.containerFilters) > 0 {
			// the chain target ignores the wrapped pair and serves the raw one
			serve := func(*Request, *Response) {
				handler.ServeHTTP(w, r)
			}
			chain := FilterChain{Filters: c.containerFilters, Target: serve}
			chain.ProcessFilter(NewRequest(r), NewResponse(w))
			return
		}
		handler.ServeHTTP(w, r)
	}))
}

// Filter appends a container FilterFunction. These are called before dispatching
// a http.Request to a WebService from the container, in the order they were added.
func (c *Container) Filter(filter FilterFunction) {
	c.containerFilters = append(c.containerFilters, filter)
}

// RegisteredWebServices returns a snapshot of the WebServices added so far.
// The slice is a copy made under the read lock; its elements point to the
// same WebService values the Container holds.
func (c *Container) RegisteredWebServices() []*WebService {
	c.webServicesLock.RLock()
	defer c.webServicesLock.RUnlock()
	snapshot := make([]*WebService, len(c.webServices))
	copy(snapshot, c.webServices)
	return snapshot
}

// computeAllowedMethods collects the HTTP methods of all routes whose path
// matches the path of req. The result may contain the same method more than
// once and is empty (not nil) when nothing matches.
func (c *Container) computeAllowedMethods(req *Request) []string {
	allowed := []string{}
	path := req.Request.URL.Path
	for _, ws := range c.RegisteredWebServices() {
		wsMatch := ws.pathExpr.Matcher.FindStringSubmatch(path)
		if wsMatch == nil {
			continue
		}
		// the last submatch is what remains after the WebService root path
		remainder := wsMatch[len(wsMatch)-1]
		for _, rt := range ws.Routes() {
			rtMatch := rt.pathExpr.Matcher.FindStringSubmatch(remainder)
			if rtMatch == nil {
				continue
			}
			// the route only counts when nothing but "" or "/" is left over
			switch rtMatch[len(rtMatch)-1] {
			case "", "/":
				allowed = append(allowed, rt.Method)
			}
		}
	}
	// "OPTIONS" is deliberately not appended here (left undecided by the author)
	return allowed
}

// newBasicRequestResponse wraps the http writer and request in a Response and
// a Request. It is basic because only the Accept header of the request is
// recorded on the Response; no parameter or (produces) content-type
// information is given.
func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
	response := NewResponse(httpWriter)
	response.requestAccept = httpRequest.Header.Get(HEADER_Accept)
	request := NewRequest(httpRequest)
	return request, response
}