diff --git a/docs/preview.rst b/docs/preview.rst index a9c11ef..d2000ce 100644 --- a/docs/preview.rst +++ b/docs/preview.rst @@ -139,3 +139,5 @@ Optional methods .. automethod:: FormPreview.security_hash .. automethod:: FormPreview.failed_hash + +.. automethod:: FormPreview.form_extra_params diff --git a/formtools/preview.py b/formtools/preview.py index 9482dde..f98173d 100644 --- a/formtools/preview.py +++ b/formtools/preview.py @@ -51,7 +51,8 @@ def unused_name(self, name): def preview_get(self, request): "Displays the form" f = self.form(auto_id=self.get_auto_id(), - initial=self.get_initial(request)) + initial=self.get_initial(request), + **self.form_extra_params(request)) return render(request, self.form_template, self.get_context(request, f)) def preview_post(self, request): @@ -61,7 +62,10 @@ def preview_post(self, request): """ # Even if files are not supported in preview, we still initialize files # to give a chance to process_preview to access files content. - f = self.form(data=request.POST, files=request.FILES, auto_id=self.get_auto_id()) + f = self.form(data=request.POST, + files=request.FILES, + auto_id=self.get_auto_id(), + **self.form_extra_params(request)) context = self.get_context(request, f) if f.is_valid(): self.process_preview(request, f, context) @@ -79,7 +83,7 @@ def post_post(self, request): """ Validates the POST data. If valid, calls done(). Else, redisplays form. """ - form = self.form(request.POST, auto_id=self.get_auto_id()) + form = self.form(request.POST, auto_id=self.get_auto_id(), **self.form_extra_params(request)) if form.is_valid(): if not self._check_security_hash( request.POST.get(self.unused_name('hash'), ''), @@ -158,6 +162,14 @@ def failed_hash(self, request): """ return self.preview_post(request) + def form_extra_params(self, request): + """ + Extra parameters to pass to the form constructor. + Returns a dictionary. + By default, returns an empty dictionary. + """ + return {} + # METHODS SUBCLASSES MUST OVERRIDE ######################################## def done(self, request, cleaned_data): diff --git a/tests/forms.py b/tests/forms.py index 85534f5..4c648b3 100644 --- a/tests/forms.py +++ b/tests/forms.py @@ -7,6 +7,10 @@ class TestForm(forms.Form): bool1 = forms.BooleanField(required=False) date1 = forms.DateField(required=False) + def __init__(self, *args, **kwargs): + self.request = kwargs.pop('request', None) + super(TestForm, self).__init__(*args, **kwargs) + class HashTestForm(forms.Form): name = forms.CharField() diff --git a/tests/tests.py b/tests/tests.py index f9a2a35..0346f1e 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -32,6 +32,9 @@ def get_context(self, request, form): def get_initial(self, request): return {'field1': 'Works!'} + def form_extra_params(self, request): + return {'request': request} + def done(self, request, cleaned_data): return http.HttpResponse(success_string) @@ -64,7 +67,7 @@ def test_parse_params_takes_request_object(self): def test_unused_name(self): """ - Verifies name mangling to get uniue field name. + Verifies name mangling to get unique field name. """ self.assertEqual(self.preview.unused_name('field1'), 'field1__')