Skip to content

Commit 4cf08ae

Browse files
authored
Thiagodeev/Fix SNIP-12 JSON Marshaling for TypedData (#789)
* refactor: update benchmarks in typedData_test.go - Moved fileNames to a global variable for better accessibility across tests. - Updated BenchmarkGetMessageHash to utilize the new fileNames variable. - Removed unused BMockTypedData helper - Added BenchmarkUnmarshalJSON to test the performance of the UnmarshalJSON function with various test data files. * test: add TestMarshalJSON to validate JSON marshaling of TypedData * refactor: add JSON tags to TypedData fields for improved JSON serialization * feat: implement MarshalJSON for TypeDefinition to enable JSON serialization * refactor: enhance Domain struct for improved JSON marshaling - Added custom logic in UnmarshalJSON to handle both `chainId` and `chain_id` fields. - Introduced flags to track the type of `chainId` for accurate marshaling. - Implemented MarshalJSON to ensure `chainId` is serialized correctly based on its original format. * refactor: enhance Domain struct for JSON marshaling with revision support - Added logic in UnmarshalJSON to handle the `revision` field, including a flag to track its type. - Updated MarshalJSON to ensure the `Revision` field is serialized correctly based on its original format. - Improved comments for clarity on the purpose of marshaling logic for both `chainId` and `revision` fields. * refactor: improve handling of chainId and chain_id in EncodeData function - Updated logic to accommodate both `chainId` and `chain_id` fields in the domainMap. - Added checks to copy values between the two fields as a workaround for compatibility. - Enhanced comments for better understanding of the changes made. * chore: update changelog with the changes * refactor: improve error messages in UnmarshalJSON for Domain struct - Updated error messages in UnmarshalJSON to specify 'JSON field' instead of 'struct' for clarity * refactor: simplify error handling in UnmarshalJSON for Domain struct - Removed redundant error check for 'chainId' retrieval in UnmarshalJSON. * refactor: enhance error handling in MarshalJSON for Domain struct - Improved error message for 'chainId' parsing to provide clearer context on failure.
1 parent ff34dcd commit 4cf08ae

File tree

3 files changed

+185
-69
lines changed

3 files changed

+185
-69
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2525
- The internal `tryUnwrapToRPCErr` func of the `rpc` pkg was renamed to `UnwrapToRPCErr` and moved to the new package.
2626
- The `Err` function now have a specific case for the `InternalError` code.
2727

28+
### Fixed
29+
- The `typedData.TypedData` was not being marshaled exactly as it is in the original JSON. Now, the original JSON is preserved,
30+
so the output of `TypedData.MarshalJSON()` is exactly as the original JSON.
31+
2832
### Dev updates
2933
- New `internal/tests/jsonrpc_spy.go` file containing a `Spy` type for spying JSON-RPC calls in tests. The
3034
old `rpc/spy_test.go` file was removed.
3135
- New `mocks/mock_client.go` file containing a mock of the `client.Client` type (`client.ClientI` interface).
36+
- New benchmarks and tests for the `typedData` pkg.
3237

3338
## [0.15.0](https://github.com/NethermindEth/starknet.go/releases/tag/v0.15.0) - 2025-09-03
3439
### Changed

typedData/typedData.go

Lines changed: 122 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,23 @@ import (
1919
var typeNameRegexp = regexp.MustCompile(`[^\(\),\s]+`)
2020

2121
type TypedData struct {
22-
Types map[string]TypeDefinition
23-
PrimaryType string
24-
Domain Domain
25-
Message map[string]any
26-
Revision *revision
22+
Types map[string]TypeDefinition `json:"types"`
23+
PrimaryType string `json:"primaryType"`
24+
Domain Domain `json:"domain"`
25+
Message map[string]any `json:"message"`
26+
Revision *revision `json:"-"`
2727
}
2828

2929
type Domain struct {
3030
Name string `json:"name"`
3131
Version string `json:"version"`
3232
ChainId string `json:"chainId"`
3333
Revision uint8 `json:"revision,omitempty"`
34+
35+
// Flags to deal with edge cases, used to marshal the `Domain` exactly as it is in the original JSON.
36+
hasStringChainId bool `json:"-"`
37+
hasOldChainIdName bool `json:"-"`
38+
HasStringRevision bool `json:"-"`
3439
}
3540

3641
type TypeDefinition struct {
@@ -421,8 +426,16 @@ func EncodeData(typeDef *TypeDefinition, td *TypedData, context ...string) (enc
421426
return enc, err
422427
}
423428

429+
// The ChainId can be either `chainId` or `chain_id`.
424430
// ref: https://community.starknet.io/t/signing-transactions-and-off-chain-messages/66
425-
domainMap["chain_id"] = domainMap["chainId"]
431+
// We find the one that contains the value and we copy the value to the other field.
432+
// It's an workaround to handle both cases.
433+
434+
if domainMap["chainId"] != nil {
435+
domainMap["chain_id"] = domainMap["chainId"]
436+
} else {
437+
domainMap["chainId"] = domainMap["chain_id"]
438+
}
426439

427440
return encodeData(typeDef, td, domainMap, false, context...)
428441
}
@@ -886,7 +899,7 @@ func (domain *Domain) UnmarshalJSON(data []byte) error {
886899
getField := func(fieldName string) (string, error) {
887900
value, ok := dec[fieldName]
888901
if !ok {
889-
return "", fmt.Errorf("error getting the value of '%s' from 'domain' struct", fieldName)
902+
return "", fmt.Errorf("error getting the value of '%s' from 'domain' JSON field", fieldName)
890903
}
891904

892905
return fmt.Sprintf("%v", value), nil
@@ -902,34 +915,125 @@ func (domain *Domain) UnmarshalJSON(data []byte) error {
902915
return err
903916
}
904917

905-
revision, err := getField("revision")
906-
if err != nil {
907-
revision = "0"
918+
// Custom logic to handle the `revision` field,
919+
// used to marshal the Revision exactly as it is in the original JSON.
920+
rawRevision, ok := dec["revision"]
921+
if !ok {
922+
rawRevision = "0"
908923
}
909-
numRevision, err := strconv.ParseUint(revision, 10, 8)
910-
if err != nil {
911-
return err
924+
925+
var numRevision uint64
926+
switch revision := rawRevision.(type) {
927+
case string:
928+
domain.HasStringRevision = true
929+
numRevision, err = strconv.ParseUint(revision, 10, 8)
930+
if err != nil {
931+
return err
932+
}
933+
case float64:
934+
domain.HasStringRevision = false
935+
numRevision = uint64(revision)
912936
}
913937

914-
chainId, err := getField("chainId")
915-
if err != nil {
938+
// Custom logic to handle the `chainId` field,
939+
// used to marshal the ChainId exactly as it is in the original JSON.
940+
rawChainId, ok := dec["chainId"]
941+
if !ok {
942+
err = errors.New("error getting the value of 'chainId' from 'domain' JSON field")
916943
if numRevision == 1 {
917944
return err
918945
}
919-
var err2 error
946+
947+
// `chain_id` was also used in the past, so we check for it if the `chainId` field is not found
920948
// ref: https://community.starknet.io/t/signing-transactions-and-off-chain-messages/66
921-
chainId, err2 = getField("chain_id")
922-
if err2 != nil {
949+
rawChainId, ok = dec["chain_id"]
950+
if !ok {
951+
err2 := errors.New("error getting the value of 'chain_id' from 'domain' JSON field")
952+
923953
return fmt.Errorf("%w: %w", err, err2)
924954
}
955+
domain.hasOldChainIdName = true
925956
}
926957

958+
switch rawChainId.(type) {
959+
case string:
960+
domain.hasStringChainId = true
961+
case float64:
962+
domain.hasStringChainId = false
963+
}
964+
chainId := fmt.Sprintf("%v", rawChainId)
965+
966+
// Final step
927967
*domain = Domain{
928968
Name: name,
929969
Version: version,
930970
ChainId: chainId,
931971
Revision: uint8(numRevision),
972+
973+
hasStringChainId: domain.hasStringChainId,
974+
hasOldChainIdName: domain.hasOldChainIdName,
975+
HasStringRevision: domain.HasStringRevision,
932976
}
933977

934978
return nil
935979
}
980+
981+
// MarshalJSON implements the json.Marshaler interface for Domain.
982+
// Some logic was added to marshal the `Domain` exactly as it is in the original JSON.
983+
func (domain Domain) MarshalJSON() ([]byte, error) {
984+
var chainId any
985+
var revision any
986+
var err error
987+
988+
// E.g: if it's `1`, we will marshal it as `1`, not `"1"`.
989+
if domain.hasStringChainId {
990+
chainId = domain.ChainId
991+
} else {
992+
chainId, err = strconv.Atoi(domain.ChainId)
993+
if err != nil {
994+
return nil, fmt.Errorf("cannot parse 'chain_id' value: %w", err)
995+
}
996+
}
997+
998+
if domain.Revision == 0 {
999+
revision = nil
1000+
} else {
1001+
if domain.HasStringRevision {
1002+
revision = strconv.FormatUint(uint64(domain.Revision), 10)
1003+
} else {
1004+
revision = domain.Revision
1005+
}
1006+
}
1007+
1008+
// The purpose here is to marshal the `Domain` exactly as it is in the original JSON.
1009+
// This is achieved by having two chainId fields, one for the old name and one for the new name,
1010+
// and using the `omitempty` tag to only include one of them, the one that is the same as the original JSON.
1011+
// Similar for the `Revision` field.
1012+
var temp struct {
1013+
Name string `json:"name"`
1014+
Version string `json:"version"`
1015+
ChainIdOld any `json:"chain_id,omitempty"` // old chainId json name
1016+
ChainIdNew any `json:"chainId,omitempty"` // new chainId json name
1017+
Revision any `json:"revision,omitempty"`
1018+
}
1019+
temp.Name = domain.Name
1020+
temp.Version = domain.Version
1021+
temp.Revision = revision
1022+
1023+
if domain.hasOldChainIdName {
1024+
temp.ChainIdOld = chainId
1025+
1026+
return json.Marshal(temp)
1027+
}
1028+
1029+
temp.ChainIdNew = chainId
1030+
1031+
return json.Marshal(temp)
1032+
}
1033+
1034+
// MarshalJSON implements the json.Marshaler interface for TypeDefinition
1035+
//
1036+
//nolint:gocritic // json.Marshaler interface requires a value receiver
1037+
func (td TypeDefinition) MarshalJSON() ([]byte, error) {
1038+
return json.Marshal(td.Parameters)
1039+
}

typedData/typedData_test.go

Lines changed: 58 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,9 @@ import (
1010
"github.com/stretchr/testify/require"
1111
)
1212

13-
var typedDataExamples = make(map[string]TypedData)
14-
15-
// TestMain initialises test data by loading TypedData examples from JSON files.
16-
// It reads multiple test files and stores them in the typedDataExamples map
17-
// before running the tests.
18-
//
19-
// Parameters:
20-
// - m: The testing.M object that provides the test runner
21-
//
22-
// Returns:
23-
// - None (calls os.Exit directly)
24-
func TestMain(m *testing.M) {
25-
fileNames := []string{
13+
var (
14+
typedDataExamples = make(map[string]TypedData)
15+
fileNames = []string{
2616
"baseExample",
2717
"example_array",
2818
"example_baseTypes",
@@ -34,7 +24,18 @@ func TestMain(m *testing.M) {
3424
"allInOne",
3525
"example_enumNested",
3626
}
27+
)
3728

29+
// TestMain initialises test data by loading TypedData examples from JSON files.
30+
// It reads multiple test files and stores them in the typedDataExamples map
31+
// before running the tests.
32+
//
33+
// Parameters:
34+
// - m: The testing.M object that provides the test runner
35+
//
36+
// Returns:
37+
// - None (calls os.Exit directly)
38+
func TestMain(m *testing.M) {
3839
for _, fileName := range fileNames {
3940
var ttd TypedData
4041
content, err := os.ReadFile(fmt.Sprintf("./testData/%s.json", fileName))
@@ -52,23 +53,52 @@ func TestMain(m *testing.M) {
5253
os.Exit(m.Run())
5354
}
5455

55-
// BMockTypedData is a helper function for benchmarks that loads a base example
56-
// TypedData from a JSON file.
57-
//
58-
// Parameters:
59-
// - b: The testing.B object used for benchmarking
60-
//
61-
// Returns:
62-
// - ttd: A TypedData instance loaded from the base example file
63-
func BMockTypedData(b *testing.B) (ttd TypedData) {
64-
b.Helper()
65-
content, err := os.ReadFile("./testData/baseExample.json")
66-
require.NoError(b, err)
56+
// BenchmarkUnmarshalJSON is a benchmark function for testing the TypedData.UnmarshalJSON function.
57+
func BenchmarkUnmarshalJSON(b *testing.B) {
58+
for _, fileName := range fileNames {
59+
rawData, err := os.ReadFile(fmt.Sprintf("./testData/%s.json", fileName))
60+
require.NoError(b, err)
6761

68-
err = json.Unmarshal(content, &ttd)
69-
require.NoError(b, err)
62+
b.Run(fileName, func(b *testing.B) {
63+
for b.Loop() {
64+
var ttd TypedData
65+
err = json.Unmarshal(rawData, &ttd)
66+
require.NoError(b, err)
67+
}
68+
})
69+
}
70+
}
71+
72+
// BenchmarkGetMessageHash is a benchmark function for testing the GetMessageHash function.
73+
func BenchmarkGetMessageHash(b *testing.B) {
74+
addr := "0xdeadbeef"
75+
76+
for key, typedData := range typedDataExamples {
77+
b.Run(key, func(b *testing.B) {
78+
for b.Loop() {
79+
_, err := typedData.GetMessageHash(addr)
80+
if err != nil {
81+
b.Fatal(err)
82+
}
83+
}
84+
})
85+
}
86+
}
87+
88+
// TestMarshalJSON tests the MarshalJSON function. It marshals the TypedData and compares the result
89+
// with the original raw data.
90+
func TestMarshalJSON(t *testing.T) {
91+
for _, filename := range fileNames {
92+
t.Run(filename, func(t *testing.T) {
93+
rawData, err := os.ReadFile(fmt.Sprintf("./testData/%s.json", filename))
94+
require.NoError(t, err)
7095

71-
return
96+
marshaledData, err := json.Marshal(typedDataExamples[filename])
97+
require.NoError(t, err)
98+
99+
require.JSONEq(t, string(rawData), string(marshaledData))
100+
})
101+
}
72102
}
73103

74104
// TestMessageHash tests the GetMessageHash function.
@@ -157,29 +187,6 @@ func TestGetMessageHash(t *testing.T) {
157187
}
158188
}
159189

160-
// BenchmarkGetMessageHash is a benchmark function for testing the GetMessageHash function.
161-
//
162-
// It tests the performance of the GetMessageHash function by running it with different input sizes.
163-
// The input size is determined by the bit length of the address parameter, which is converted from
164-
// a hexadecimal string to a big integer using the HexToBN function from the utils package.
165-
//
166-
// Parameters:
167-
// - b: a testing.B object that provides methods for benchmarking the function
168-
//
169-
// Returns:
170-
//
171-
// none
172-
func BenchmarkGetMessageHash(b *testing.B) {
173-
ttd := BMockTypedData(b)
174-
175-
addr := "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"
176-
b.Run(fmt.Sprintf("input_size_%d", len(addr)), func(b *testing.B) {
177-
result, err := ttd.GetMessageHash(addr)
178-
require.NoError(b, err)
179-
require.NotEmpty(b, result)
180-
})
181-
}
182-
183190
// TestGeneral_GetTypeHash tests the GetTypeHash function.
184191
//
185192
// It tests the GetTypeHash function by calling it with different input values

0 commit comments

Comments
 (0)